def find_children(program, idx, value_map):
    '''Return the program locations that read data written by ``program[idx]``.

    The remote call at ``idx`` is evaluated concretely with ``value_map``.
    Each ``lp.RemoteWrite`` it emits carries a concrete ``matrix`` and
    ``bidxs`` offset. Every abstract remote call in ``program`` is then scanned
    for ``IndexExpr`` arguments whose symbolically evaluated page equals that
    matrix. For those arguments, ``template_match`` is asked for the matches
    between the concrete write and the abstract read.

    Args:
        program: mapping-like object keyed by program location (supports
            ``keys()`` and ``[]``); each value exposes ``.remote_call``
            (with ``.args``) and ``.scope``.
        idx: key of the remote call whose consumers are wanted.
        value_map: concrete values handed to ``eval_remote_call`` to evaluate
            ``program[idx]`` (presumably its loop-variable bindings; confirm).

    Returns:
        ``integerify_solutions`` applied to the de-duplicated list of
        ``(program_location, match)`` pairs. ``idx`` itself is not excluded
        from the search.
    '''
    evaluated = eval_remote_call(program[idx], value_map)
    consumers = []
    # Only writes can feed a downstream reader; every other instruction is ignored.
    writes = (ins for ins in evaluated.instrs if isinstance(ins, lp.RemoteWrite))
    for write in writes:
        written_page, written_offset = write.matrix, write.bidxs
        for loc in program.keys():
            entry = program[loc]
            abstract_call = entry.remote_call
            # One fresh scope per candidate location, shared by all of its
            # arguments (eval_index_expr / template_match may modify it).
            local_scope = copy_scope(entry.scope)
            for operand in abstract_call.args:
                if not isinstance(operand, IndexExpr):
                    continue
                sym_page, sym_offset = eval_index_expr(
                    operand, local_scope, dummify=True)
                if sym_page != written_page:
                    continue
                index_types = [ix.type for ix in operand.indices]
                matches = template_match(written_page, written_offset,
                                         sym_page, sym_offset,
                                         index_types, local_scope)
                consumers.extend((loc, match) for match in matches)
    return integerify_solutions(utils.remove_duplicates(consumers))
def find_parents(program, idx, value_map):
    '''Return the program locations that write data read by ``program[idx]``.

    The remote call at ``idx`` is evaluated concretely with ``value_map``.
    Each ``lp.RemoteRead`` it emits carries a concrete ``matrix`` and
    ``bidxs`` offset. Every abstract remote call in ``program`` is then scanned
    for ``IndexExpr`` outputs whose symbolically evaluated page equals that
    matrix. For those outputs, ``template_match`` is asked for the matches
    between the concrete read and the abstract write.

    Args:
        program: mapping-like object keyed by program location (supports
            ``keys()`` and ``[]``); each value exposes ``.remote_call``
            (with ``.output``) and ``.scope``.
        idx: key of the remote call whose producers are wanted.
        value_map: concrete values handed to ``eval_remote_call`` to evaluate
            ``program[idx]`` (presumably its loop-variable bindings; confirm).

    Returns:
        ``integerify_solutions`` applied to the de-duplicated list of
        ``(program_location, match)`` pairs.

    Raises:
        Exception: if a single abstract output matches one concrete read more
            than once, since programs are required to be in SSA form.
    '''
    evaluated = eval_remote_call(program[idx], value_map)
    producers = []
    # Only reads depend on an upstream writer; every other instruction is ignored.
    reads = (ins for ins in evaluated.instrs if isinstance(ins, lp.RemoteRead))
    for read in reads:
        read_page, read_offset = read.matrix, read.bidxs
        for loc in program.keys():
            entry = program[loc]
            abstract_call = entry.remote_call
            # One fresh scope per candidate location, shared by all of its
            # outputs (eval_index_expr / template_match may modify it).
            local_scope = copy_scope(entry.scope)
            for target in abstract_call.output:
                if not isinstance(target, IndexExpr):
                    continue
                sym_page, sym_offset = eval_index_expr(
                    target, local_scope, dummify=True)
                if sym_page != read_page:
                    continue
                index_types = [ix.type for ix in target.indices]
                matches = template_match(read_page, read_offset,
                                         sym_page, sym_offset,
                                         index_types, local_scope)
                # No single IndexExpr may have multiple parents (SSA invariant).
                if len(matches) > 1:
                    raise Exception(
                        "Invalid Program Graph, LambdaPackPrograms must be SSA"
                    )
                producers.extend((loc, match) for match in matches)
    return integerify_solutions(utils.remove_duplicates(producers))