def inline_temporaries(expressions, ops):
    """Drop Evaluate terminals whose values can be inlined cheaply.

    A temporary qualifies when its expression is scalar (shape ``()``),
    is referenced exactly once in ``expressions``, and is not used by a
    node with strictly more free indices (inlining there would move the
    computation into an inner loop).

    :arg expressions: a multi-root GEM expression DAG, used for
                      reference counting
    :arg ops: ordered list of Impero terminals
    :returns: a new list holding the members of ``ops`` in their original
              order, minus the :class:`impero.Evaluate` terminals chosen
              for inlining
    """
    counts = collect_refcount(expressions)

    # Scalar temporaries with a single reference are inlining candidates.
    inlinable = {
        op.expression
        for op in ops
        if isinstance(op, imp.Evaluate)
        and op.expression.shape == ()
        and counts[op.expression] == 1
    }

    # A candidate whose user carries strictly more free indices would be
    # pulled into an inner loop if inlined, so it must stay a temporary.
    for parent in traversal(expressions):
        for child in parent.children:
            if child in inlinable and set(child.free_indices) < set(parent.free_indices):
                inlinable.remove(child)

    # Keep every op except the Evaluates of the surviving candidates.
    return [op for op in ops
            if not isinstance(op, imp.Evaluate) or op.expression not in inlinable]
def emit_operations(assignments, get_indices, emit_return_accumulate=True):
    """Makes an ordering of operations to evaluate a multi-root
    expression DAG.

    :arg assignments: Iterable of (variable, expression) pairs.
                      The value of expression is written into variable
                      upon execution.  Any iterable is accepted, including
                      one-shot iterators such as generators.
    :arg get_indices: mapping from GEM nodes to an ordering of free
                      indices
    :arg emit_return_accumulate: emit ReturnAccumulate nodes?  Set to
                                 False if the output variables are not
                                 guaranteed zero on entry to the kernel.
    :returns: list of Impero terminals correctly ordered to evaluate
              the assignments
    """
    # ``assignments`` is traversed twice below (reference counting, then
    # staging).  Materialise it first.  With a generator, the first pass
    # would exhaust it and the second would silently stage nothing.
    assignments = list(assignments)

    # Prepare reference counts
    refcount = collect_refcount([expression for _, expression in assignments])

    # Stage return operations.  An assignment becomes a ReturnAccumulate
    # when all of the following hold:
    #   - the flag allows it;
    #   - the right-hand side is an IndexSum referenced only here;
    #   - both sides carry the same set of free indices.
    staging = []
    for variable, expression in assignments:
        if (emit_return_accumulate
                and refcount[expression] == 1
                and isinstance(expression, gem.IndexSum)
                and set(variable.free_indices) == set(expression.free_indices)):
            staging.append(impero.ReturnAccumulate(variable, expression))
            # The ReturnAccumulate takes over this reference, so drop it
            # from the counts handed to the stager.  Presumably this stops
            # the IndexSum from also being scheduled as a separate
            # temporary -- TODO confirm against ReferenceStager.
            refcount[expression] -= 1
        else:
            staging.append(impero.Return(variable, expression))

    # Prepare data structures.  The closures look up ``queue`` only when
    # called, which happens after it is bound below.
    def push_node(node):
        queue.insert(get_indices(node), node)

    def push_op(op):
        queue.insert(op.loop_shape(get_indices), op)

    ops = []
    stager = ReferenceStager(refcount, push_node)
    queue = Queue(functools.partial(handle, ops, push_op, stager.decref))

    # Enqueue return operations
    for op in staging:
        push_op(op)

    # Schedule operations
    queue.process()

    # Assert that nothing left unprocessed
    assert stager.empty()

    # ``ops`` is built in the reverse of the returned order; flip it into
    # evaluation order before handing it back.
    ops.reverse()
    return ops
def emit_operations(assignments, get_indices, emit_return_accumulate=True):
    """Makes an ordering of operations to evaluate a multi-root
    expression DAG.

    :arg assignments: Iterable of (variable, expression) pairs.
                      The value of expression is written into variable
                      upon execution.  Any iterable is accepted, including
                      one-shot iterators such as generators.
    :arg get_indices: mapping from GEM nodes to an ordering of free
                      indices
    :arg emit_return_accumulate: emit ReturnAccumulate nodes?  Defaults to
                                 True, which is the previous behaviour of
                                 this function.  Set to False if the output
                                 variables are not guaranteed zero on entry
                                 to the kernel.
    :returns: list of Impero terminals correctly ordered to evaluate
              the assignments
    """
    # ``assignments`` is traversed twice below (reference counting, then
    # staging).  Materialise it first.  With a generator, the first pass
    # would exhaust it and the second would silently stage nothing.
    assignments = list(assignments)

    # Prepare reference counts
    refcount = collect_refcount([expression for _, expression in assignments])

    # Stage return operations.  An assignment becomes a ReturnAccumulate
    # when all of the following hold:
    #   - the flag allows it;
    #   - the right-hand side is an IndexSum referenced only here;
    #   - both sides carry the same set of free indices.
    staging = []
    for variable, expression in assignments:
        if (emit_return_accumulate
                and refcount[expression] == 1
                and isinstance(expression, gem.IndexSum)
                and set(variable.free_indices) == set(expression.free_indices)):
            staging.append(impero.ReturnAccumulate(variable, expression))
            # The ReturnAccumulate takes over this reference, so drop it
            # from the counts handed to the stager.  Presumably this stops
            # the IndexSum from also being scheduled as a separate
            # temporary -- TODO confirm against ReferenceStager.
            refcount[expression] -= 1
        else:
            staging.append(impero.Return(variable, expression))

    # Prepare data structures.  The closures look up ``queue`` only when
    # called, which happens after it is bound below.
    def push_node(node):
        queue.insert(get_indices(node), node)

    def push_op(op):
        queue.insert(op.loop_shape(get_indices), op)

    ops = []
    stager = ReferenceStager(refcount, push_node)
    queue = Queue(functools.partial(handle, ops, push_op, stager.decref))

    # Enqueue return operations
    for op in staging:
        push_op(op)

    # Schedule operations
    queue.process()

    # Assert that nothing left unprocessed
    assert stager.empty()

    # ``ops`` is built in the reverse of the returned order; flip it into
    # evaluation order before handing it back.
    ops.reverse()
    return ops
def pprint(expression_dags, context=global_context):
    """Print the GEM expression DAGs in ``expression_dags`` to stdout.

    The function walks ``expression_dags`` in post-order.  Each node
    selected by ``must_name`` is passed to ``context.force_expression``.
    Every node for which ``context.expression`` returns a name is printed
    as its own ``<decl> = <expr>`` line.  Finally each root is printed
    under the label ``#1``, ``#2``, ... in the order given.

    :arg expression_dags: the root expressions; they are iterated more
                          than once, so a re-iterable collection is
                          expected.
    :arg context: naming context consulted while printing.
    """
    counts = collect_refcount(expression_dags)

    def must_name(node):
        # Variables are never forced.
        if isinstance(node, gem.Variable):
            return False
        # Non-scalar nodes are always forced.  Constants and (flexibly)
        # indexed scalars are not.  Any other scalar is forced only when
        # it is referenced more than once.
        return bool(node.shape) or (
            not isinstance(node, (gem.Constant, gem.Indexed, gem.FlexiblyIndexed))
            and counts[node] > 1
        )

    for node in post_traversal(expression_dags):
        if must_name(node):
            context.force_expression(node)
        label = context.expression(node)
        if label is None:
            continue
        print(make_decl(node, label, context), '=', to_str(node, context, top=True))

    for index, root in enumerate(expression_dags, start=1):
        print(make_decl(root, "#%d" % index, context), '=', to_str(root, context))