Exemplo n.º 1
0
def promote_scalar_expressions(exprs, shape, indices, onstack):
    """
    Promote the scalar temporaries found in ``exprs`` to tensors.

    For every node of the temporaries graph flagged as scalar, a new
    ``TensorFunction`` is created with the same name as the scalar and with
    the given ``shape``, ``indices`` (used as dimensions) and ``onstack``
    values; the scalar on the LHS is then replaced by an ``Indexed`` access
    to that function over ``indices``. Non-scalar nodes keep their LHS.

    Once all LHS have been processed, each promoted scalar appearing in any
    RHS is substituted with its indexed counterpart.

    Returns the list of rewritten equations, in graph order.
    """
    graph = temporaries_graph(exprs)

    # First pass: promote the LHS, recording scalar -> indexed substitutions
    substitutions = {}
    promoted = []
    for lhs, node in graph.items():
        if not node.is_scalar:
            # Nothing to promote, the equation is rebuilt as is
            promoted.append(Eq(lhs, node.rhs))
            continue

        # A new function symbol taking the place of the scalar
        function = TensorFunction(name=lhs.name,
                                  shape=shape,
                                  dimensions=indices,
                                  onstack=onstack)
        access = Indexed(function.indexed, *indices)
        substitutions[lhs] = access
        promoted.append(Eq(access, node.rhs))

    # Second pass: propagate the promoted LHS through all of the RHS
    return [Eq(eq.lhs, eq.rhs.xreplace(substitutions)) for eq in promoted]
Exemplo n.º 2
0
def clusterize(exprs, stencils):
    """
    Derive :class:`Cluster` objects from an iterator of expressions.

    A stencil must be provided for each expression: ``exprs`` and
    ``stencils`` are asserted to have the same length. One cluster is
    created for each tensor expression found in the temporaries graph; the
    clusters are passed through ``merge`` before being returned.
    """
    assert len(exprs) == len(stencils)

    exprs, stencils = aggregate(exprs, stencils)
    graph = temporaries_graph(exprs)

    # Associate the LHS of each tensor expression with its stencil
    tensor_stencils = OrderedDict()
    for node, local in zip(graph.values(), stencils):
        if node.is_tensor:
            tensor_stencils[node.lhs] = local

    clusters = []
    for target, local in tensor_stencils.items():
        # The expressions (temporaries included) needed to compute /target/
        trace = graph.trace(target)

        # The cluster stencil is the stencil of /target/ extended with the
        # stencils of all tensor expressions appearing in its trace
        stencil = Stencil(local.entries)
        for e in trace:
            stencil = stencil.add(tensor_stencils.get(e.lhs, {}))
        stencil = stencil.frozen

        # Only expressions writing to a Symbol, plus /target/ itself, are
        # retained; the other tensors are computed by other clusters
        retained = [e for e in trace if e.lhs.is_Symbol or e.lhs == target]

        clusters.append(Cluster(retained, stencil))

    return merge(clusters)
Exemplo n.º 3
0
 def reschedule(self, exprs):
     """
     Return a new cluster made of the expressions ``exprs``.

     The new cluster shares the stencil and the atomics of ``self``. The
     expressions are ordered by building a temporaries graph out of
     ``exprs`` and asking it to reschedule against ``self.exprs``, so that
     the ordering of ``self`` is honored.
     """
     graph = temporaries_graph(exprs)
     ordered = graph.reschedule(self.exprs)
     return Cluster(ordered, self.stencil, self.atomics)
Exemplo n.º 4
0
def clusterize(exprs, stencils, atomics=None):
    """
    Derive :class:`Cluster` objects from an iterator of expressions.

    A stencil must be provided for each expression: ``exprs`` and
    ``stencils`` are asserted to have the same length. A list of atomic
    dimensions (see description in Cluster.__doc__) may be provided; it is
    forwarded unchanged to every created cluster.

    One cluster is created for each tensor expression found in the
    temporaries graph; the clusters are passed through ``merge`` before
    being returned.
    """
    assert len(exprs) == len(stencils)

    exprs, stencils = aggregate(exprs, stencils)

    Info = namedtuple('Info', 'trace stencil')

    graph = temporaries_graph(exprs)

    # For each tensor expression, track its trace (what it needs, followed by
    # what is obtained from the ``readby=True, strict=True`` query) together
    # with its stencil
    mapper = OrderedDict()
    for (lhs, node), local in zip(graph.items(), stencils):
        if node.is_tensor:
            trace = graph.trace(lhs) + graph.trace(lhs, readby=True, strict=True)
            mapper[lhs] = Info(trace, local)

    # The stencil of a cluster is computed iteratively: start from the
    # "local" stencil, then add the stencils of all other tracked tensors in
    # the trace. Whenever a stencil changes, the tensors in the trace are
    # queued again, until no more updates occur.
    pending = list(mapper)
    while pending:
        target = pending.pop(0)
        current = mapper[target]

        # Every LHS in the trace other than /target/ itself
        others = [e.lhs for e in current.trace if e.lhs != target]

        stencil = Stencil(current.stencil.entries)
        for lhs in others:
            if lhs in mapper:
                stencil = stencil.add(mapper[lhs].stencil)

        mapper[target] = Info(current.trace, stencil)

        if stencil != current.stencil:
            # The stencil was updated, so the change must be propagated.
            # The filter is evaluated against the queue as it is *before*
            # extending it.
            pending.extend([lhs for lhs in others if lhs not in pending])

    clusters = []
    for target, info in mapper.items():
        # Only expressions writing to a Symbol, plus /target/ itself, are
        # retained; the other tensors are computed by other clusters
        retained = [e for e in info.trace if e.lhs.is_Symbol or e.lhs == target]
        clusters.append(Cluster(retained, info.stencil.frozen, atomics))

    return merge(clusters)
Exemplo n.º 5
0
def compact_temporaries(exprs):
    """
    Drop temporaries consisting of single symbols.

    Each node of the temporaries graph flagged as dead is removed from the
    output. The first symbol it reads is instead made to write directly to
    the key of the dead node, and all retained expressions have that symbol
    replaced accordingly.

    Returns the list of processed expressions, in graph order.
    """
    graph = temporaries_graph(exprs)

    # Map the (first) symbol read by each dead temporary to the key of the
    # dead temporary itself
    replacements = {}
    for lhs, node in graph.items():
        if node.is_dead:
            replacements[list(node.reads)[0]] = lhs

    compacted = []
    for lhs, node in graph.items():
        if lhs in replacements:
            # /lhs/ feeds a dead temporary: write to the latter's key directly
            compacted.append(Eq(replacements[lhs], node.rhs))
            continue
        if node.is_dead:
            # Dead temporaries are dropped
            continue
        compacted.append(node.xreplace(replacements))

    return compacted
Exemplo n.º 6
0
def compact_temporaries(exprs):
    """
    Drop temporaries consisting of single symbols.

    A temporary is dropped if it is scalar and its RHS either satisfies
    ``q_leaf`` or is a Function. All other temporaries are retained, with
    the dropped ones substituted (repeatedly) by their RHS.

    Returns the list of retained expressions, in graph order.
    """
    graph = temporaries_graph(exprs)

    # The temporaries to be dropped, each mapped to its replacement
    dropped = {}
    for lhs, node in graph.items():
        if node.is_scalar and (q_leaf(node.rhs) or node.rhs.is_Function):
            dropped[lhs] = node.rhs

    retained = []
    for lhs, node in graph.items():
        if lhs in dropped:
            continue
        # The temporary /node/ is kept, and substitutions may be applied
        rewritten, _ = xreplace_constrained(node, dropped, repeat=True)
        assert len(rewritten) == 1
        retained.extend(rewritten)

    return retained
Exemplo n.º 7
0
 def __init__(self, exprs, stencil, atomics):
     """
     Initialize the object from ``exprs``, ``stencil`` and ``atomics``.

     The expressions are stored as a temporaries graph in ``self.trace``,
     the stencil is stored as given, and ``atomics`` is normalized through
     ``as_tuple``.
     """
     graph = temporaries_graph(exprs)
     self.trace = graph
     self.stencil = stencil
     normalized = as_tuple(atomics)
     self.atomics = normalized
Exemplo n.º 8
0
 def __init__(self, exprs, stencil):
     """
     Initialize the object from ``exprs`` and ``stencil``.

     The expressions are stored as a temporaries graph in ``self.trace``;
     the stencil is stored as given.
     """
     graph = temporaries_graph(exprs)
     self.trace = graph
     self.stencil = stencil