예제 #1
0
def inline_temporaries(expressions, ops):
    """Drop the :class:`impero.Evaluate` operations whose expression can
    be inlined without blowing up the code.

    An expression is kept as an inlining candidate when it is
    scalar-valued (``shape == ()``) and referenced exactly once in
    ``expressions``.  A candidate is rejected again if its free indices
    are a strict subset of its parent's free indices.

    :arg expressions: a multi-root GEM expression DAG, used for
                      reference counting
    :arg ops: ordered list of Impero terminals
    :returns: a filtered ``ops``, without the unnecessary
              :class:`impero.Evaluate`s
    """
    refcount = collect_refcount(expressions)

    # Scalar-valued expressions of Evaluate operations that are
    # referenced exactly once are the candidates for inlining.
    candidates = {
        op.expression
        for op in ops
        if isinstance(op, imp.Evaluate)
        and op.expression.shape == ()
        and refcount[op.expression] == 1
    }

    # Prevent inlining that pulls expressions into inner loops: a child
    # with strictly fewer free indices than its parent stays a temporary.
    for parent in traversal(expressions):
        for child in parent.children:
            if child not in candidates:
                continue
            if set(child.free_indices) < set(parent.free_indices):
                candidates.remove(child)

    def is_inlined(op):
        """Tell whether ``op`` evaluates an expression chosen for inlining."""
        return isinstance(op, imp.Evaluate) and op.expression in candidates

    return [op for op in ops if not is_inlined(op)]
예제 #2
0
def inline_temporaries(expressions, ops):
    """Filter out the :class:`impero.Evaluate` operations that can be
    inlined without blowing up the code.

    Inlining candidates are the scalar-valued (``shape == ()``)
    expressions of :class:`impero.Evaluate` operations which are
    referenced exactly once in ``expressions``.  Candidates whose free
    indices form a strict subset of their parent's free indices are
    excluded.

    :arg expressions: a multi-root GEM expression DAG, used for
                      reference counting
    :arg ops: ordered list of Impero terminals
    :returns: a filtered ``ops``, without the unnecessary
              :class:`impero.Evaluate`s
    """
    refcount = collect_refcount(expressions)

    # Candidates for inlining
    candidates = {
        op.expression
        for op in ops
        if isinstance(op, imp.Evaluate)
        and op.expression.shape == ()
        and refcount[op.expression] == 1
    }

    # Inlining a child that has strictly fewer free indices than its
    # parent would pull it into inner loops, so keep it as a temporary.
    for parent in traversal(expressions):
        for child in parent.children:
            if child not in candidates:
                continue
            if set(child.free_indices) < set(parent.free_indices):
                candidates.remove(child)

    def is_inlined(op):
        """Tell whether ``op`` evaluates an expression chosen for inlining."""
        return isinstance(op, imp.Evaluate) and op.expression in candidates

    return [op for op in ops if not is_inlined(op)]
예제 #3
0
파일: firedrake.py 프로젝트: inducer/tsfc
 def needs_cell_orientations(ir):
     """Report whether a multi-root GEM expression DAG references cell
     orientations.

     :arg ir: roots of the GEM expression DAG to search
     :returns: ``True`` if any visited node is a :class:`gem.Variable`
               named ``"cell_orientations"``, ``False`` otherwise
     """
     return any(
         isinstance(node, gem.Variable) and node.name == "cell_orientations"
         for node in traversal(ir)
     )
예제 #4
0
def translate_coefficient(terminal, mt, ctx):
    """Build the GEM expression that evaluates a coefficient terminal.

    :arg terminal: coefficient terminal; only its ``ufl_element()`` is
                   used directly in this function
    :arg mt: modified terminal; its ``restriction``,
             ``local_derivatives`` and ``expr`` attributes are read
    :arg ctx: translation context providing ``coefficient``,
              ``create_element``, ``basis_evaluation``, ``entity_ids``,
              ``entity_number``, ``epsilon``, ``index_cache`` and
              ``point_indices``
    :returns: GEM expression whose shape equals ``mt.expr.ufl_shape``
              and whose free indices are a subset of
              ``ctx.point_indices`` (both are asserted below)
    """
    vec = ctx.coefficient(terminal, mt.restriction)

    # 'Real' family coefficients are returned as they come from the
    # context; no basis tabulation is involved and derivatives are not
    # expected.
    if terminal.ufl_element().family() == 'Real':
        assert mt.local_derivatives == 0
        return vec

    element = ctx.create_element(terminal.ufl_element())

    # Collect FInAT tabulation for all entities
    # (maps each derivative multiindex to one table per entity, in the
    # order of ``ctx.entity_ids``)
    per_derivative = collections.defaultdict(list)
    for entity_id in ctx.entity_ids:
        finat_dict = ctx.basis_evaluation(element, mt.local_derivatives, entity_id)
        for alpha, table in finat_dict.items():
            # Filter out irrelevant derivatives
            if sum(alpha) == mt.local_derivatives:
                # A numerical hack that FFC used to apply on FIAT
                # tables still lives on after ditching FFC and
                # switching to FInAT.
                table = ffc_rounding(table, ctx.epsilon)
                per_derivative[alpha].append(table)

    # Merge entity tabulations for each derivative
    if len(ctx.entity_ids) == 1:
        # Single entity: unwrap the one-element list of tables
        def take_singleton(xs):
            x, = xs  # asserts singleton
            return x
        per_derivative = {alpha: take_singleton(tables)
                          for alpha, tables in per_derivative.items()}
    else:
        # Several entities: pick the table using the entity number
        # expression obtained for this restriction
        f = ctx.entity_number(mt.restriction)
        per_derivative = {alpha: gem.select_expression(tables, f)
                          for alpha, tables in per_derivative.items()}

    # Coefficient evaluation
    # The basis indices are cached per UFL element, so the same indices
    # are reused on later calls with the same element.
    ctx.index_cache.setdefault(terminal.ufl_element(), element.get_indices())
    beta = ctx.index_cache[terminal.ufl_element()]
    zeta = element.get_value_indices()
    vec_beta, = gem.optimise.remove_componenttensors([gem.Indexed(vec, beta)])
    value_dict = {}
    for alpha, table in per_derivative.items():
        table_qi = gem.Indexed(table, beta + zeta)
        summands = []
        # Each (var, expr) pair produced by ``unconcatenate`` is
        # multiplied, summed over ``var.index_ordering()`` and optimised
        # separately; the pieces are then added together.
        for var, expr in unconcatenate([(vec_beta, table_qi)], ctx.index_cache):
            value = gem.IndexSum(gem.Product(expr, var), var.index_ordering())
            summands.append(gem.optimise.contraction(value))
        optimised_value = gem.optimise.make_sum(summands)
        value_dict[alpha] = gem.ComponentTensor(optimised_value, zeta)

    # Change from FIAT to UFL arrangement
    result = fiat_to_ufl(value_dict, mt.local_derivatives)
    assert result.shape == mt.expr.ufl_shape
    assert set(result.free_indices) <= set(ctx.point_indices)

    # Detect Jacobian of affine cells
    # (heuristic: no free indices, and every Literal in the result has
    # at most two nonzero entries)
    if not result.free_indices and all(numpy.count_nonzero(node.array) <= 2
                                       for node in traversal((result,))
                                       if isinstance(node, gem.Literal)):
        result = gem.optimise.aggressive_unroll(result)
    return result
예제 #5
0
 def needs_cell_orientations(ir):
     """Tell whether a multi-root GEM expression DAG references cell
     orientations.

     :arg ir: roots of the GEM expression DAG to search
     :returns: ``True`` if any visited node is a :class:`gem.Variable`
               named ``"cell_orientations"``, ``False`` otherwise
     """
     return any(
         isinstance(node, gem.Variable) and node.name == "cell_orientations"
         for node in traversal(ir)
     )
예제 #6
0
def optimise_expressions(expressions, argument_indices):
    """Perform loop optimisations on GEM DAGs

    :arg expressions: list of GEM DAGs
    :arg argument_indices: tuple of argument indices

    :returns: list of optimised GEM DAGs; the input ``expressions`` is
              returned unchanged if it contains a ``Failure`` node
    """
    # Skip optimisation if a Failure node is present
    if any(isinstance(node, Failure) for node in traversal(expressions)):
        return expressions

    def classify(argument_indices, expression):
        """Classify ``expression`` by the argument indices it depends on.

        ``OTHER`` if it shares no free index with ``argument_indices``,
        ``ATOMIC`` if it is an ``Indexed`` node sharing exactly one,
        ``COMPOUND`` in every other case.
        """
        shared = argument_indices.intersection(expression.free_indices)
        if not shared:
            return OTHER
        if len(shared) == 1 and isinstance(expression, Indexed):
            return ATOMIC
        return COMPOUND

    # Apply argument factorisation unconditionally
    classifier = partial(classify, set(argument_indices))
    monomial_sums = collect_monomials(expressions, classifier)
    return [optimise_monomial_sum(monomial_sum, argument_indices)
            for monomial_sum in monomial_sums]
예제 #7
0
def compile_gem(return_variables, expressions, prefix_ordering, remove_zeros=False):
    """Compiles GEM to Impero.

    :arg return_variables: return variables for each root (type: GEM expressions)
    :arg expressions: multi-root expression DAG (type: GEM expressions)
    :arg prefix_ordering: outermost loop indices
    :arg remove_zeros: remove zero assignment to return variables
    :returns: :class:`ImperoC` holding the Impero AST together with the
              temporaries and declaration data for code generation
    :raises NoopError: if scheduling emits no operations
    """
    expressions = optimise.remove_componenttensors(expressions)

    # Remove zeros: drop every (variable, expression) pair whose
    # expression is a gem.Zero.
    if remove_zeros:
        kept = [(var, expr)
                for var, expr in zip(return_variables, expressions)
                if not isinstance(expr, gem.Zero)]
        return_variables = [var for var, _ in kept]
        expressions = [expr for _, expr in kept]

    def multiindex_entries(node):
        """Yield, in order, the index entries that ``node`` is indexed with."""
        if isinstance(node, gem.Indexed):
            for entry in node.multiindex:
                yield entry
        elif isinstance(node, gem.FlexiblyIndexed):
            for _offset, idxs in node.dim2idxs:
                for entry, _stride in idxs:
                    yield entry

    # Collect indices in a deterministic order
    seen_indices = OrderedSet()
    for node in traversal(expressions):
        for entry in multiindex_entries(node):
            if isinstance(entry, gem.Index):
                seen_indices.add(entry)

    # Build ordered index map
    index_ordering = make_prefix_ordering(seen_indices, prefix_ordering)
    apply_ordering = make_index_orderer(index_ordering)

    def get_indices(expr):
        return apply_ordering(expr.free_indices)

    # Build operation ordering
    assignments = list(zip(return_variables, expressions))
    ops = scheduling.emit_operations(assignments, get_indices)

    # Empty kernel
    if len(ops) == 0:
        raise NoopError()

    # Drop unnecessary temporaries
    ops = inline_temporaries(expressions, ops)

    # Build Impero AST, then collect temporaries and determine
    # declarations from the operation list.
    tree = make_loop_tree(ops, get_indices)
    temporaries = collect_temporaries(ops)
    declare, indices = place_declarations(ops, tree, temporaries, get_indices)

    # Prepare ImperoC (Impero AST + other data for code generation)
    return ImperoC(tree, temporaries, declare, indices)
예제 #8
0
def compile_gem(assignments, prefix_ordering, remove_zeros=False):
    """Compiles GEM to Impero.

    :arg assignments: list of (return variable, expression DAG root) pairs
    :arg prefix_ordering: outermost loop indices
    :arg remove_zeros: remove zero assignment to return variables
    :returns: :class:`ImperoC` holding the Impero AST together with the
              temporaries and declaration data for code generation
    :raises NoopError: if scheduling emits no operations
    """
    # Remove zeros: keep only the assignments whose right-hand side is
    # not a gem.Zero.
    if remove_zeros:
        kept = []
        for assignment in assignments:
            _variable, rhs = assignment
            if not isinstance(rhs, gem.Zero):
                kept.append(assignment)
        assignments = kept

    # Just the expressions
    expressions = [rhs for _variable, rhs in assignments]

    def multiindex_entries(node):
        """Yield, in order, the index entries that ``node`` is indexed with."""
        if isinstance(node, gem.Indexed):
            for entry in node.multiindex:
                yield entry
        elif isinstance(node, gem.FlexiblyIndexed):
            for _offset, idxs in node.dim2idxs:
                for entry, _stride in idxs:
                    yield entry

    # Collect indices in a deterministic order
    seen_indices = OrderedSet()
    for node in traversal(expressions):
        for entry in multiindex_entries(node):
            if isinstance(entry, gem.Index):
                seen_indices.add(entry)

    # Build ordered index map
    index_ordering = make_prefix_ordering(seen_indices, prefix_ordering)
    apply_ordering = make_index_orderer(index_ordering)

    def get_indices(expr):
        return apply_ordering(expr.free_indices)

    # Build operation ordering
    ops = scheduling.emit_operations(assignments, get_indices)

    # Empty kernel
    if len(ops) == 0:
        raise NoopError()

    # Drop unnecessary temporaries
    ops = inline_temporaries(expressions, ops)

    # Build Impero AST, then collect temporaries and determine
    # declarations from that tree.
    tree = make_loop_tree(ops, get_indices)
    temporaries = collect_temporaries(tree)
    declare, indices = place_declarations(tree, temporaries, get_indices)

    # Prepare ImperoC (Impero AST + other data for code generation)
    return ImperoC(tree, temporaries, declare, indices)
예제 #9
0
def compile_gem(return_variables, expressions, prefix_ordering, remove_zeros=False):
    """Compiles GEM to Impero.

    :arg return_variables: return variables for each root (type: GEM expressions)
    :arg expressions: multi-root expression DAG (type: GEM expressions)
    :arg prefix_ordering: outermost loop indices
    :arg remove_zeros: remove zero assignment to return variables
    :returns: :class:`ImperoC` holding the Impero AST together with the
              temporaries and declaration data for code generation
    :raises NoopError: if scheduling emits no operations
    """
    # Remove zeros: drop every (variable, expression) pair whose
    # expression is a gem.Zero.
    if remove_zeros:
        kept = [(var, expr)
                for var, expr in zip(return_variables, expressions)
                if not isinstance(expr, gem.Zero)]
        return_variables = [var for var, _ in kept]
        expressions = [expr for _, expr in kept]

    def multiindex_entries(node):
        """Yield, in order, the index entries that ``node`` is indexed with."""
        if isinstance(node, gem.Indexed):
            for entry in node.multiindex:
                yield entry
        elif isinstance(node, gem.FlexiblyIndexed):
            for _offset, idxs in node.dim2idxs:
                for entry, _stride in idxs:
                    yield entry

    # Collect indices in a deterministic order
    seen_indices = OrderedSet()
    for node in traversal(expressions):
        for entry in multiindex_entries(node):
            if isinstance(entry, gem.Index):
                seen_indices.add(entry)

    # Build ordered index map
    index_ordering = make_prefix_ordering(seen_indices, prefix_ordering)
    apply_ordering = make_index_orderer(index_ordering)

    def get_indices(expr):
        return apply_ordering(expr.free_indices)

    # Build operation ordering
    assignments = list(zip(return_variables, expressions))
    ops = scheduling.emit_operations(assignments, get_indices)

    # Empty kernel
    if len(ops) == 0:
        raise NoopError()

    # Drop unnecessary temporaries
    ops = inline_temporaries(expressions, ops)

    # Build Impero AST, then collect temporaries and determine
    # declarations from that tree.
    tree = make_loop_tree(ops, get_indices)
    temporaries = collect_temporaries(tree)
    declare, indices = place_declarations(tree, temporaries, get_indices)

    # Prepare ImperoC (Impero AST + other data for code generation)
    return ImperoC(tree, temporaries, declare, indices)
예제 #10
0
    def classify(atomics_set, expression):
        """Classify ``expression`` relative to ``atomics_set``.

        :returns: ``ATOMIC`` if ``expression`` itself is in
                  ``atomics_set``; ``COMPOUND`` if any node visited by
                  traversing ``expression`` is in ``atomics_set``;
                  ``OTHER`` otherwise
        """
        if expression in atomics_set:
            return ATOMIC

        contains_atomic = any(node in atomics_set
                              for node in traversal([expression]))
        return COMPOUND if contains_atomic else OTHER
예제 #11
0
    def classify(atomics_set, expression):
        """Classify ``expression`` relative to ``atomics_set``.

        :returns: ``ATOMIC`` when ``expression`` is a member of
                  ``atomics_set``, ``COMPOUND`` when some node reached by
                  traversing ``expression`` is a member, and ``OTHER``
                  when neither holds
        """
        if expression in atomics_set:
            return ATOMIC

        if any(node in atomics_set for node in traversal([expression])):
            return COMPOUND

        return OTHER
예제 #12
0
def collect_temporaries(tree):
    """Collects the GEM expressions to assign to temporaries by
    walking an Impero tree.

    :arg tree: root of the Impero tree to traverse
    :returns: list of GEM expressions, in traversal order: the
              ``indexsum`` of each :class:`imp.Accumulate` and the
              ``expression`` of each :class:`imp.Evaluate`
    """
    temporaries = []
    for node in traversal((tree,)):
        # IndexSum temporaries should be added either at Initialise or
        # at Accumulate.  The difference is only in ordering
        # (numbering).  We chose Accumulate here.
        if isinstance(node, imp.Accumulate):
            temporary = node.indexsum
        elif isinstance(node, imp.Evaluate):
            temporary = node.expression
        else:
            continue
        temporaries.append(temporary)
    return temporaries
예제 #13
0
def collect_temporaries(tree):
    """Gathers, from an Impero tree, the GEM expressions which are
    assigned to temporaries.

    :arg tree: root of the Impero tree to traverse
    :returns: list of GEM expressions, in traversal order: the
              ``indexsum`` of each :class:`imp.Accumulate` and the
              ``expression`` of each :class:`imp.Evaluate`
    """
    temporaries = []
    for node in traversal((tree,)):
        # IndexSum temporaries should be added either at Initialise or
        # at Accumulate.  The difference is only in ordering
        # (numbering).  We chose Accumulate here.
        if isinstance(node, imp.Accumulate):
            temporary = node.indexsum
        elif isinstance(node, imp.Evaluate):
            temporary = node.expression
        else:
            continue
        temporaries.append(temporary)
    return temporaries
예제 #14
0
def compile_gem(assignments, prefix_ordering, remove_zeros=False):
    """Compiles GEM to Impero.

    :arg assignments: list of (return variable, expression DAG root) pairs
    :arg prefix_ordering: outermost loop indices
    :arg remove_zeros: remove zero assignment to return variables
    :returns: :class:`ImperoC` holding the Impero AST together with the
              temporaries and declaration data for code generation
    :raises NoopError: if scheduling emits no operations
    """
    # Remove zeros: keep only the assignments whose right-hand side is
    # not a gem.Zero.
    if remove_zeros:
        kept = []
        for assignment in assignments:
            _variable, rhs = assignment
            if not isinstance(rhs, gem.Zero):
                kept.append(assignment)
        assignments = kept

    # Just the expressions
    expressions = [rhs for _variable, rhs in assignments]

    # Collect indices in a deterministic order: first occurrence wins,
    # duplicates are ignored.
    first_seen = collections.OrderedDict()
    for node in traversal(expressions):
        if isinstance(node, (gem.Indexed, gem.FlexiblyIndexed)):
            for index in node.index_ordering():
                first_seen.setdefault(index)
    collected_indices = list(first_seen)

    # Build ordered index map
    index_ordering = make_prefix_ordering(collected_indices, prefix_ordering)
    apply_ordering = make_index_orderer(index_ordering)

    def get_indices(expr):
        return apply_ordering(expr.free_indices)

    # Build operation ordering
    ops = scheduling.emit_operations(assignments, get_indices)

    # Empty kernel
    if len(ops) == 0:
        raise NoopError()

    # Drop unnecessary temporaries
    ops = inline_temporaries(expressions, ops)

    # Build Impero AST, then collect temporaries and determine
    # declarations from that tree.
    tree = make_loop_tree(ops, get_indices)
    temporaries = collect_temporaries(tree)
    declare, indices = place_declarations(tree, temporaries, get_indices)

    # Prepare ImperoC (Impero AST + other data for code generation)
    return ImperoC(tree, temporaries, declare, indices)
예제 #15
0
def check_requirements(ir):
    """Look for cell orientations, cell sizes, and collect tabulations
    in one pass.

    :arg ir: roots of the GEM expression DAG to search
    :returns: triple ``(cell_orientations, cell_sizes, tabulations)``:
              two booleans telling whether a :class:`gem.Variable` named
              ``"cell_orientations"`` / ``"cell_sizes"`` was found, and a
              sorted tuple of ``(name, shape)`` pairs for the variables
              whose name starts with ``"rt_"``
    """
    found_orientations = False
    found_sizes = False
    tabulations = {}
    for node in traversal(ir):
        if not isinstance(node, gem.Variable):
            continue
        name = node.name
        if name == "cell_orientations":
            found_orientations = True
        elif name == "cell_sizes":
            found_sizes = True
        elif name.startswith("rt_"):
            tabulations[name] = node.shape
    return found_orientations, found_sizes, tuple(sorted(tabulations.items()))
예제 #16
0
파일: optimise.py 프로젝트: tj-sun/tsfc
def contraction(expression):
    """Optimise the contractions of the tensor product at the root of
    the expression, including:

    - IndexSum-Delta cancellation
    - Sum factorisation

    This routine was designed with finite element coefficient
    evaluation in mind.

    :arg expression: GEM expression to optimise
    :returns: the rebuilt GEM expression
    """
    # Eliminate annoying ComponentTensors
    expression, = remove_componenttensors([expression])

    def rebuild(expr):
        """Flatten product tree, eliminate deltas, sum factorise."""
        sum_indices, factors = delta_elimination(*traverse_product(expr))
        return sum_factorise(sum_indices, remove_componenttensors(factors))

    # Sometimes the value shape is composed as a ListTensor, which
    # could get in the way of decomposing factors.  In particular,
    # this is the case for H(div) and H(curl) conforming tensor
    # product elements.  So if ListTensors are used, they are pulled
    # out to be outermost, so we can straightforwardly factorise each
    # of its entries.
    #
    # An OrderedDict with None values serves as an ordered set of the
    # indices used to index ListTensors.
    listtensor_indices = OrderedDict()
    for node in traversal((expression, )):
        if not isinstance(node, Indexed):
            continue
        child, = node.children
        if isinstance(child, ListTensor):
            for index in node.multiindex:
                listtensor_indices.setdefault(index)

    # ListTensor free indices
    lt_fis = tuple(index for index in listtensor_indices
                   if index in expression.free_indices)

    if not lt_fis:
        # Rebuild whole expression at once
        return rebuild(expression)

    # Rebuild each split component
    tensor = ComponentTensor(expression, lt_fis)
    entries = remove_componenttensors(
        [Indexed(tensor, zeta) for zeta in numpy.ndindex(tensor.shape)])
    rebuilt = numpy.array([rebuild(entry) for entry in entries])
    return Indexed(ListTensor(rebuilt.reshape(tensor.shape)), lt_fis)
예제 #17
0
파일: coffee_mode.py 프로젝트: inducer/tsfc
def optimise_expressions(expressions, argument_indices):
    """Perform loop optimisations on GEM DAGs

    :arg expressions: list of GEM DAGs
    :arg argument_indices: tuple of argument indices

    :returns: list of optimised GEM DAGs; the input ``expressions`` is
              returned unchanged if it contains a ``Failure`` node
    """
    # Skip optimisation if a Failure node is present
    if any(isinstance(node, Failure) for node in traversal(expressions)):
        return expressions

    # Apply argument factorisation unconditionally
    classifier = partial(spectral.classify, set(argument_indices))
    optimised = []
    for monomial_sum in collect_monomials(expressions, classifier):
        optimised.append(optimise_monomial_sum(monomial_sum, argument_indices))
    return optimised
예제 #18
0
def optimise_expressions(expressions, argument_indices):
    """Perform loop optimisations on GEM DAGs

    :arg expressions: list of GEM DAGs
    :arg argument_indices: tuple of argument indices

    :returns: list of optimised GEM DAGs; if a ``Failure`` node is
              present, ``expressions`` is handed back untouched
    """
    # Skip optimisation if a Failure node is present
    has_failure = any(isinstance(node, Failure)
                      for node in traversal(expressions))
    if has_failure:
        return expressions

    # Apply argument factorisation unconditionally
    classifier = partial(spectral.classify, set(argument_indices))
    optimised = []
    for monomial_sum in collect_monomials(expressions, classifier):
        optimised.append(optimise_monomial_sum(monomial_sum, argument_indices))
    return optimised
예제 #19
0
파일: refactorise.py 프로젝트: peide/tsfc
def collect_monomials(expressions, classifier):
    """Refactorises expressions into a sum-of-products form, using
    distributivity rules (i.e. a*(b + c) -> a*b + a*c).  Expansion
    proceeds until all "compound" expressions are broken up.

    :arg expressions: GEM expressions to refactorise
    :arg classifier: a function that can classify any GEM expression
                     as ``ATOMIC``, ``COMPOUND``, or ``OTHER``.  This
                     classification drives the factorisation.

    :returns: list of :py:class:`MonomialSum`s

    :raises FactorisationError: Failed to break up some "compound"
                                expressions with expansion.
    """
    def is_compound(node):
        """Tell whether the classifier rates ``node`` as COMPOUND."""
        return classifier(node) == COMPOUND

    # Get ComponentTensors out of the way
    expressions = remove_componenttensors(expressions)

    # Get ListTensors out of the way: gather the indices of every
    # COMPOUND Indexed node that indexes a ListTensor ...
    must_unroll = set()  # indices to unroll
    for node in traversal(expressions):
        if not isinstance(node, Indexed):
            continue
        child, = node.children
        if isinstance(child, ListTensor) and is_compound(node):
            must_unroll.update(node.multiindex)

    # ... and unroll the index sums over those indices.
    if must_unroll:
        expressions = unroll_indexsum(expressions,
                                      predicate=lambda i: i in must_unroll)
        expressions = remove_componenttensors(expressions)

    # Expand Conditional nodes which are COMPOUND
    expressions = expand_conditional(expressions, is_compound)

    # Finally, refactorise expressions
    mapper = Memoizer(_collect_monomials)
    mapper.classifier = classifier
    mapper.rename_map = make_rename_map()
    return [mapper(expression) for expression in expressions]
예제 #20
0
파일: refactorise.py 프로젝트: inducer/tsfc
def collect_monomials(expressions, classifier):
    """Refactorises expressions into a sum-of-products form, using
    distributivity rules (i.e. a*(b + c) -> a*b + a*c).  Expansion
    proceeds until all "compound" expressions are broken up.

    :arg expressions: GEM expressions to refactorise
    :arg classifier: a function that can classify any GEM expression
                     as ``ATOMIC``, ``COMPOUND``, or ``OTHER``.  This
                     classification drives the factorisation.

    :returns: list of :py:class:`MonomialSum`s

    :raises FactorisationError: Failed to break up some "compound"
                                expressions with expansion.
    """
    def is_compound(node):
        """Tell whether the classifier rates ``node`` as COMPOUND."""
        return classifier(node) == COMPOUND

    # Get ComponentTensors out of the way
    expressions = remove_componenttensors(expressions)

    # Get ListTensors out of the way: collect the indices used by
    # COMPOUND Indexed nodes whose child is a ListTensor.
    must_unroll = set()  # indices to unroll
    for node in traversal(expressions):
        if not isinstance(node, Indexed):
            continue
        child, = node.children
        if isinstance(child, ListTensor) and is_compound(node):
            must_unroll.update(node.multiindex)

    # Unroll the index sums over the collected indices, if any.
    if must_unroll:
        expressions = unroll_indexsum(expressions,
                                      predicate=lambda i: i in must_unroll)
        expressions = remove_componenttensors(expressions)

    # Expand Conditional nodes which are COMPOUND
    expressions = expand_conditional(expressions, is_compound)

    # Finally, refactorise expressions
    mapper = Memoizer(_collect_monomials)
    mapper.classifier = classifier
    mapper.rename_map = make_rename_map()
    return [mapper(expression) for expression in expressions]
예제 #21
0
def place_declarations(tree, temporaries, get_indices):
    """Determines where and how to declare temporaries for an Impero AST.

    :arg tree: Impero AST to determine the declarations for
    :arg temporaries: list of GEM expressions which are assigned to
                      temporaries
    :arg get_indices: callable mapping from GEM nodes to an ordering
                      of free indices
    :returns: pair ``(declare, indices)``.  ``declare`` maps every
              :class:`imp.Block` to the list of temporaries declared
              at its beginning, and every :class:`imp.Terminal` to a
              bool telling whether the terminal declares the temporary
              it writes to.  ``indices`` maps each temporary to the
              part of its index ordering that is left after the
              surrounding loop indices.
    :raises AssertionError: if the tree contains a node type other
                            than :class:`imp.Terminal`,
                            :class:`imp.For` or :class:`imp.Block`
    """
    numbering = {t: n for n, t in enumerate(temporaries)}
    # Temporaries must be unique, otherwise the numbering is ambiguous
    assert len(numbering) == len(temporaries)

    # Collect the total number of temporary references
    total_refcount = collections.Counter()
    for node in traversal((tree,)):
        if isinstance(node, imp.Terminal):
            total_refcount.update(temp_refcount(numbering, node))
    assert set(total_refcount) == set(temporaries)

    # Result
    declare = {}
    indices = {}

    @singledispatch
    def recurse(expr, loop_indices):
        """Visit an Impero AST to collect declarations.

        :arg expr: Impero tree node
        :arg loop_indices: loop indices (in order) from the outer
                           loops surrounding ``expr``
        :returns: :class:`collections.Counter` with the reference
                  counts for each temporary in the subtree whose root
                  is ``expr``
        """
        # Fail right here for unknown node types.  Returning the
        # exception object instead would hand it to the caller as if
        # it were a reference count.
        raise AssertionError("unsupported expression type %s" % type(expr))

    @recurse.register(imp.Terminal)
    def recurse_terminal(expr, loop_indices):
        return temp_refcount(numbering, expr)

    @recurse.register(imp.For)
    def recurse_for(expr, loop_indices):
        # The loop body sees one more surrounding loop index
        return recurse(expr.children[0], loop_indices + (expr.index,))

    @recurse.register(imp.Block)
    def recurse_block(expr, loop_indices):
        # Temporaries declared at the beginning of the block are
        # collected here
        declare[expr] = []

        # Collect reference counts for the block
        refcount = collections.Counter()
        for statement in expr.children:
            refcount.update(recurse(statement, loop_indices))

        # Visit :class:`collections.Counter` in deterministic order
        # (``sorted`` builds a list first, so deleting entries from
        # ``refcount`` inside the loop is safe)
        for e in sorted(refcount.keys(), key=lambda t: numbering[t]):
            if refcount[e] == total_refcount[e]:
                # If all references are within this block, then this
                # block is the right place to declare the temporary.
                assert loop_indices == get_indices(e)[:len(loop_indices)]
                indices[e] = get_indices(e)[len(loop_indices):]
                if indices[e]:
                    # Scalar-valued temporaries are not declared until
                    # their value is assigned.  This does not really
                    # matter, but produces a more compact and nicer to
                    # read C code.
                    declare[expr].append(e)
                # Remove expression from the ``refcount`` so it will
                # not be declared again.
                del refcount[e]
        return refcount

    # Populate result
    remainder = recurse(tree, ())
    # Every temporary must have found its declaring block
    assert not remainder

    # Set in ``declare`` for Impero terminals whether they should
    # declare the temporary that they are writing to.
    for node in traversal((tree,)):
        if isinstance(node, imp.Terminal):
            declare[node] = False
            if isinstance(node, imp.Evaluate):
                e = node.expression
            elif isinstance(node, imp.Initialise):
                e = node.indexsum
            else:
                continue

            # No indices left: the temporary is declared where it is
            # first written.
            if len(indices[e]) == 0:
                declare[node] = True

    return declare, indices
예제 #22
0
def place_declarations(tree, temporaries, get_indices):
    """Determines where and how to declare temporaries for an Impero AST.

    :arg tree: Impero AST to determine the declarations for
    :arg temporaries: list of GEM expressions which are assigned to
                      temporaries
    :arg get_indices: callable mapping from GEM nodes to an ordering
                      of free indices
    :returns: pair ``(declare, indices)``.  ``declare`` maps every
              :class:`imp.Block` to the list of temporaries declared
              at its beginning, and every :class:`imp.Terminal` to a
              bool telling whether the terminal declares the temporary
              it writes to.  ``indices`` maps each temporary to the
              part of its index ordering that is left after the
              surrounding loop indices.
    :raises AssertionError: if the tree contains a node type other
                            than :class:`imp.Terminal`,
                            :class:`imp.For` or :class:`imp.Block`
    """
    numbering = {t: n for n, t in enumerate(temporaries)}
    # Temporaries must be unique, otherwise the numbering is ambiguous
    assert len(numbering) == len(temporaries)

    # Collect the total number of temporary references
    total_refcount = collections.Counter()
    for node in traversal((tree,)):
        if isinstance(node, imp.Terminal):
            total_refcount.update(temp_refcount(numbering, node))
    assert set(total_refcount) == set(temporaries)

    # Result
    declare = {}
    indices = {}

    @singledispatch
    def recurse(expr, loop_indices):
        """Visit an Impero AST to collect declarations.

        :arg expr: Impero tree node
        :arg loop_indices: loop indices (in order) from the outer
                           loops surrounding ``expr``
        :returns: :class:`collections.Counter` with the reference
                  counts for each temporary in the subtree whose root
                  is ``expr``
        """
        # Fail right here for unknown node types.  Returning the
        # exception object instead would hand it to the caller as if
        # it were a reference count.
        raise AssertionError("unsupported expression type %s" % type(expr))

    @recurse.register(imp.Terminal)
    def recurse_terminal(expr, loop_indices):
        return temp_refcount(numbering, expr)

    @recurse.register(imp.For)
    def recurse_for(expr, loop_indices):
        # The loop body sees one more surrounding loop index
        return recurse(expr.children[0], loop_indices + (expr.index,))

    @recurse.register(imp.Block)
    def recurse_block(expr, loop_indices):
        # Temporaries declared at the beginning of the block are
        # collected here
        declare[expr] = []

        # Collect reference counts for the block
        refcount = collections.Counter()
        for statement in expr.children:
            refcount.update(recurse(statement, loop_indices))

        # Visit :class:`collections.Counter` in deterministic order
        # (``sorted`` builds a list first, so deleting entries from
        # ``refcount`` inside the loop is safe)
        for e in sorted(refcount.keys(), key=lambda t: numbering[t]):
            if refcount[e] == total_refcount[e]:
                # If all references are within this block, then this
                # block is the right place to declare the temporary.
                assert loop_indices == get_indices(e)[:len(loop_indices)]
                indices[e] = get_indices(e)[len(loop_indices):]
                if indices[e]:
                    # Scalar-valued temporaries are not declared until
                    # their value is assigned.  This does not really
                    # matter, but produces a more compact and nicer to
                    # read C code.
                    declare[expr].append(e)
                # Remove expression from the ``refcount`` so it will
                # not be declared again.
                del refcount[e]
        return refcount

    # Populate result
    remainder = recurse(tree, ())
    # Every temporary must have found its declaring block
    assert not remainder

    # Set in ``declare`` for Impero terminals whether they should
    # declare the temporary that they are writing to.
    for node in traversal((tree,)):
        if isinstance(node, imp.Terminal):
            declare[node] = False
            if isinstance(node, imp.Evaluate):
                e = node.expression
            elif isinstance(node, imp.Initialise):
                e = node.indexsum
            else:
                continue

            # No indices left: the temporary is declared where it is
            # first written.
            if len(indices[e]) == 0:
                declare[node] = True

    return declare, indices