Esempio n. 1
0
def propagate_memlet(dfg_state,
                     memlet: Memlet,
                     scope_node: nodes.EntryNode,
                     union_inner_edges: bool,
                     arr=None):
    """ Propagates a memlet outward through a scope node.

        Computes the image of the memlet's access function over the scope's
        iteration space (e.g., a map range) and returns it as a new memlet.

        :param dfg_state: The SDFGState that contains the scope.
        :param memlet: The memlet on the inner side of ``scope_node``.
        :param scope_node: The entry or exit node of the scope.
        :param union_inner_edges: If True, other memlets on the inner side of
                                  ``scope_node`` that refer to the same data
                                  are unioned into the result.
        :param arr: Data descriptor of ``memlet.data``; looked up in the
                    parent SDFG's ``arrays`` when omitted.
        :return: The propagated memlet, or an empty ``Memlet()`` if
                 ``memlet`` is empty.
        :raises TypeError: If ``scope_node`` is not an entry or exit node.
        :raises KeyError: If ``arr`` is omitted and ``memlet.data`` is not a
                          data descriptor of the SDFG.
        :raises NotImplementedError: If the scope is neither a map nor a
                                     consume scope.
    """
    # Resolve the scope's entry node and the edges on its inner side.
    if isinstance(scope_node, nodes.EntryNode):
        entry = scope_node
        inner_edges = dfg_state.out_edges(scope_node)
    elif isinstance(scope_node, nodes.ExitNode):
        entry = dfg_state.scope_dict()[scope_node]
        inner_edges = dfg_state.in_edges(scope_node)
    else:
        raise TypeError('Trying to propagate through a non-scope node')

    if memlet.is_empty():
        return Memlet()

    parent_sdfg = dfg_state.parent

    # Input connectors of the entry node that are not prefixed with "IN_"
    # are excluded from the symbols treated as defined at the scope.
    local_symbols = {
        name for name in entry.in_connectors if not name.startswith('IN_')
    }
    known_symbols = []
    for sym_name in dfg_state.symbols_defined_at(entry).keys():
        if sym_name not in local_symbols:
            known_symbols.append(symbolic.pystr_to_symbolic(sym_name))

    # Memlets to union: optionally the sibling memlets of the same data on
    # the inner side of the scope, always followed by the given memlet.
    to_union = []
    if union_inner_edges:
        to_union.extend(
            edge.data for edge in inner_edges
            if edge.data.data == memlet.data and edge.data != memlet)
    to_union.append(memlet)

    if arr is None:
        if memlet.data not in parent_sdfg.arrays:
            raise KeyError('Data descriptor (Array, Stream) "%s" not defined '
                           'in SDFG.' % memlet.data)
        arr = parent_sdfg.arrays[memlet.data]

    if isinstance(entry, nodes.MapEntry):
        the_map = entry.map
        return propagate_subset(to_union, arr, the_map.params, the_map.range,
                                known_symbols)

    if isinstance(entry, nodes.ConsumeEntry):
        # Consume scopes are not analyzed: the result covers the whole array
        # and is marked dynamic with a volume of 0.
        result = copy.copy(memlet)
        result.subset = subsets.Range.from_array(arr)
        result.other_subset = None
        result.volume = 0
        result.dynamic = True
        return result

    raise NotImplementedError('Unimplemented primitive: %s' % type(entry))
Esempio n. 2
0
def propagate_memlet(dfg_state,
                     memlet: Memlet,
                     scope_node: nodes.EntryNode,
                     union_inner_edges: bool,
                     arr=None):
    """ Tries to propagate a memlet through a scope (computes the image of
        the memlet function applied on an integer set of, e.g., a map range)
        and returns a new memlet object.
        :param dfg_state: An SDFGState object representing the graph.
        :param memlet: The memlet adjacent to the scope node from the inside.
        :param scope_node: A scope entry or exit node.
        :param union_inner_edges: True if the propagation should take other
                                  neighboring internal memlets within the same
                                  scope into account.
        :param arr: The data descriptor of ``memlet.data``. If None, it is
                    looked up in the ``arrays`` of the state's parent SDFG.
        :return: A copy of ``memlet`` with the propagated subset and number
                 of accesses, or an empty ``Memlet()`` if ``memlet`` is empty.
        :raises TypeError: If ``scope_node`` is not a scope entry/exit node.
        :raises KeyError: If ``arr`` is None and ``memlet.data`` is not a
                          data descriptor of the SDFG.
        :raises NotImplementedError: If the scope is neither a map nor a
                                     consume scope.
    """
    if isinstance(scope_node, nodes.EntryNode):
        entry_node = scope_node
        neighboring_edges = dfg_state.out_edges(scope_node)
    elif isinstance(scope_node, nodes.ExitNode):
        entry_node = dfg_state.scope_dict()[scope_node]
        neighboring_edges = dfg_state.in_edges(scope_node)
    else:
        raise TypeError('Trying to propagate through a non-scope node')
    if memlet.is_empty():
        return Memlet()

    sdfg = dfg_state.parent
    # Input connectors of the entry node not prefixed with "IN_" are excluded
    # from the set of symbols treated as defined at this scope.
    scope_node_symbols = set(conn for conn in entry_node.in_connectors
                             if not conn.startswith('IN_'))
    defined_vars = [
        symbolic.pystr_to_symbolic(s)
        for s in dfg_state.symbols_defined_at(entry_node).keys()
        if s not in scope_node_symbols
    ]

    # Find other adjacent edges within the scope that refer to the same data
    # and union their subsets
    if union_inner_edges:
        aggdata = [
            e.data for e in neighboring_edges
            if e.data.data == memlet.data and e.data != memlet
        ]
    else:
        aggdata = []

    aggdata.append(memlet)

    if arr is None:
        if memlet.data not in sdfg.arrays:
            raise KeyError('Data descriptor (Array, Stream) "%s" not defined '
                           'in SDFG.' % memlet.data)
        arr = sdfg.arrays[memlet.data]

    # Propagate subset and number of accesses
    if isinstance(entry_node, nodes.MapEntry):
        mapnode = entry_node.map

        variable_context = [
            defined_vars,
            [symbolic.pystr_to_symbolic(p) for p in mapnode.params]
        ]

        new_subset = None
        for md in aggdata:
            for pclass in MemletPattern.extensions():
                pattern = pclass()
                if pattern.match([md.subset], variable_context, mapnode.range,
                                 [md]):
                    tmp_subset = pattern.propagate(arr, [md.subset],
                                                   mapnode.range)
                    break
            else:
                # No patterns found. Emit a warning and propagate the entire
                # array
                warnings.warn('Cannot find appropriate memlet pattern to '
                              'propagate %s through %s' %
                              (str(md.subset), str(mapnode.range)))
                tmp_subset = subsets.Range.from_array(arr)

            # Union edges as necessary
            if new_subset is None:
                new_subset = tmp_subset
            else:
                old_subset = new_subset
                new_subset = subsets.union(new_subset, tmp_subset)
                if new_subset is None:
                    warnings.warn('Subset union failed between %s and %s ' %
                                  (old_subset, tmp_subset))
                    break

        # Some unions failed: fall back to the entire array
        if new_subset is None:
            new_subset = subsets.Range.from_array(arr)

        # Number of accesses in the propagated memlet is the sum of the
        # internal number of accesses times the size of the map range set.
        # The map is taken from the entry node, so entry and exit nodes of
        # the same scope yield the same count.
        num_accesses = (
            sum(m.num_accesses for m in aggdata) *
            functools.reduce(lambda a, b: a * b, mapnode.range.size(), 1))
        dynamic = any(m.dynamic for m in aggdata)

    elif isinstance(entry_node, nodes.ConsumeEntry):
        # Nothing to analyze/propagate in consume. The number of consumed
        # elements is not known statically, so the count is left unknown
        # (0) and the memlet is marked dynamic. This replaces the previous
        # "map range size" product, which does not apply to consume scopes.
        new_subset = subsets.Range.from_array(arr)
        num_accesses = 0
        dynamic = True
    else:
        raise NotImplementedError('Unimplemented primitive: %s' %
                                  type(entry_node))
    ### End of subset propagation

    new_memlet = copy.copy(memlet)
    new_memlet.subset = new_subset
    new_memlet.other_subset = None
    new_memlet.num_accesses = num_accesses

    if dynamic:
        new_memlet.dynamic = True
    elif symbolic.issymbolic(num_accesses) and any(
            s not in defined_vars for s in num_accesses.free_symbols):
        # The count depends on symbols outside ``defined_vars``; mark the
        # memlet dynamic with an unknown (0) number of accesses instead.
        new_memlet.dynamic = True
        new_memlet.num_accesses = 0

    return new_memlet