def create_invocation_end(ctxt: InvocationContext, builder: GraphTypeBuilder):
    """Add the ``{invocation}_end`` output pin to every involved set.

    The send handler asserts that the invocation is in progress with both
    send masks clear and the receive counters at their expected values,
    then resets the in-progress flag and both receive counters.  For the
    iteration set each written global is copied from device state into
    the message; for every other set the message field is zeroed.
    """
    completion_code = """
    assert(deviceState->{invocation}_in_progress);
    assert( deviceState->{invocation}_read_send_mask == 0 );
    assert( deviceState->{invocation}_write_send_mask == 0 );
    assert( deviceState->{invocation}_read_recv_count == 0 );
    assert( deviceState->{invocation}_write_recv_count == deviceProperties->{invocation}_write_recv_total );

    deviceState->{invocation}_in_progress=0;
    deviceState->{invocation}_read_recv_count=0;
    deviceState->{invocation}_write_recv_count=0;
    """
    for involved_set in ctxt.get_all_involved_sets():
        with builder.subst(set=involved_set.id):
            if involved_set == ctxt.stat.iter_set:
                # Only the iteration set holds the real global values.
                per_global = """
                copy_value(message->global_{0}, deviceState->global_{0});
                """
            else:
                per_global = """
                zero_value(message->global_{0});
                """
            handler = completion_code + "".join(
                per_global.format(arg.global_.id)
                for (_, arg) in ctxt.global_writes)
            builder.add_output_pin("set_{set}", "{invocation}_end",
                                   "{invocation}_end", handler)
def create_indirect_read_recvs(ctxt: InvocationContext,
                               builder: GraphTypeBuilder):
    """Add one ``{invocation}_arg{index}_read_recv`` input pin per indirect read.

    Each pin is added to the iteration set's device type, carries a
    ``uint8`` ``index`` edge property and receives ``dat_{dat}`` messages.
    The handler starts the invocation if it is not already in progress,
    counts the receive, and copies the value into the argument's landing
    buffer (indexed by the edge property when ``arg.index`` is negative).
    """
    pin_name = "{invocation}_arg{index}_read_recv"
    for (arg_pos, arg) in ctxt.indirect_reads:
        with builder.subst(set=ctxt.stat.iter_set.id, index=arg_pos,
                           dat=arg.dat.id):
            edge_props = {"index": DataType(numpy.uint8, shape=())}
            builder.add_input_pin(
                "set_{set}", pin_name, "dat_{dat}", edge_props, None, """
                assert(deviceState->{invocation}_read_recv_count < {invocation}_{set}_read_recv_total);

                // Standard edge trigger for start of invocation
                if(!deviceState->{invocation}_in_progress){{
                    deviceState->{invocation}_in_progress=1;
                    deviceState->{invocation}_read_send_mask = {invocation}_{set}_read_send_mask_all;
                }}

                deviceState->{invocation}_read_recv_count++;
                """)
            if arg.index < 0:
                # Negative index: the buffer has one slot per incoming edge.
                store_code = """
                copy_value(deviceState->{invocation}_arg{index}_buffer[edgeProperties->index], message->value);
                """
            else:
                store_code = """
                copy_value(deviceState->{invocation}_arg{index}_buffer, message->value);
                """
            builder.extend_input_pin_handler("set_{set}", pin_name,
                                             store_code)
def create_indirect_read_sends(ctxt: InvocationContext,
                               builder: GraphTypeBuilder):
    """Add a ``{invocation}_dat_{dat}_read_send`` output pin per indirectly read dat.

    The pin is added to the device type of the set that owns the dat.  Its
    handler asserts that the pin's flag is set in the read-send mask,
    copies the dat into the message, and clears the flag.
    """
    send_handler = """
    assert(deviceState->{invocation}_read_send_mask & RTS_FLAG_{invocation}_dat_{dat}_read_send);
    copy_value(message->value, deviceState->dat_{dat});
    deviceState->{invocation}_read_send_mask &= ~RTS_FLAG_{invocation}_dat_{dat}_read_send;
    """
    for source_dat in ctxt.get_indirect_read_dats():
        with builder.subst(dat=source_dat.id, set=source_dat.set.id):
            builder.add_output_pin("set_{set}",
                                   "{invocation}_dat_{dat}_read_send",
                                   "dat_{dat}", send_handler)
def create_indirect_landing_pads(ctxt: InvocationContext,
                                 builder: GraphTypeBuilder):
    """Add an ``{invocation}_arg{index}_buffer`` state member per indirect argument.

    Covers both indirect reads and indirect writes; the buffers are added
    to the iteration set's device type.  When ``arg.index`` is negative the
    buffer gains a leading dimension of ``-arg.index`` in front of the
    dat's own shape; otherwise it has exactly the dat's data type.
    """
    indirect_args = ctxt.indirect_reads | ctxt.indirect_writes
    for (arg_pos, arg) in indirect_args:
        with builder.subst(index=arg_pos, set=ctxt.stat.iter_set.id):
            buffer_type = arg.dat.data_type
            if arg.index < 0:
                buffer_type = DataType(
                    arg.dat.data_type.dtype,
                    (-arg.index, ) + arg.dat.data_type.shape)
            builder.add_device_state("set_{set}",
                                     "{invocation}_arg{index}_buffer",
                                     buffer_type)
def create_indirect_write_recvs(ctxt: InvocationContext,
                                builder: GraphTypeBuilder):
    """Add one ``{invocation}_arg{index}_write_recv`` input pin per indirect write.

    The pin is added to the device type of ``arg.to_set``.  For a
    non-negative ``arg.index`` it receives plain ``dat_{dat}`` messages
    with no edge properties; otherwise it receives ``dat_{dat}_x{arity}``
    messages and uses a ``uint32`` ``index`` edge property to pick the
    element.  WRITE and RW arguments overwrite the dat, INC arguments add
    to it.

    Raises:
        RuntimeError: if the argument's access mode is none of WRITE, RW
            or INC.
    """
    for (arg_pos, arg) in ctxt.indirect_writes:
        with builder.subst(set=arg.to_set.id, index=arg_pos, dat=arg.dat.id,
                           arity=arg.map.arity):
            single_value = arg.index >= 0
            props = None if single_value else {"index": scalar_uint32}
            message_type = "dat_{dat}" if single_value else "dat_{dat}_x{arity}"

            handler = """
            assert(deviceState->{invocation}_write_recv_count < deviceProperties->{invocation}_write_recv_total);

            // Standard edge trigger for start of invocation
            if(!deviceState->{invocation}_in_progress){{
                deviceState->{invocation}_in_progress=1;
                deviceState->{invocation}_read_send_mask = {invocation}_{set}_read_send_mask_all;
            }}

            deviceState->{invocation}_write_recv_count++;
            """

            if arg.access_mode in (AccessMode.WRITE, AccessMode.RW):
                if single_value:
                    handler += """
                    copy_value(deviceState->dat_{dat}, message->value);
                    """
                else:
                    handler += """
                    copy_value(deviceState->dat_{dat}, message->value[edgeProperties->index]);
                    """
            elif arg.access_mode == AccessMode.INC:
                if single_value:
                    handler += """
                    inc_value(deviceState->dat_{dat}, message->value);
                    """
                else:
                    handler += """
                    inc_value(deviceState->dat_{dat}, message->value[edgeProperties->index]);
                    """
            else:
                raise RuntimeError("Unexpected access mode {}".format(
                    arg.access_mode))

            builder.add_input_pin("set_{set}",
                                  "{invocation}_arg{index}_write_recv",
                                  message_type, props, None, handler)
def create_read_send_masks(ctxt: InvocationContext, builder: GraphTypeBuilder):
    """Emit ``{invocation}_{set}_read_send_mask_all`` for every involved set.

    The constant is the bitwise OR of ``0`` and the
    ``RTS_FLAG_{invocation}_dat_{dat}_read_send`` flag of each dat that an
    indirect read fetches from that set (keyed by ``arg.to_set``).  It is
    added as device shared code to the set's device type.

    The dats are collected in a Python ``set``, whose iteration order is
    arbitrary, so the flags are sorted by dat id before being joined.
    That keeps the generated text identical from run to run; the value
    of the mask does not depend on the order.
    """
    read_sends = {s: set() for s in ctxt.get_all_involved_sets()}
    for (_, arg) in ctxt.indirect_reads:
        read_sends[arg.to_set].add(arg.dat)

    for (set_, dats) in read_sends.items():
        # str() so the sort works whatever type the ids are.
        ordered_dats = sorted(dats, key=lambda dat: str(dat.id))
        flags = ["0"] + [
            builder.s("RTS_FLAG_{invocation}_dat_{dat}_read_send", dat=dat.id)
            for dat in ordered_dats
        ]
        read_send_set = "(" + "|".join(flags) + ")"
        with builder.subst(set=set_.id, read_send_set=read_send_set):
            builder.add_device_shared_code(
                "set_{set}", """
                const uint32_t {invocation}_{set}_read_send_mask_all = {read_send_set};
                """)
def create_invocation_execute(ctxt: InvocationContext,
                              builder: GraphTypeBuilder):
    """Add the ``{invocation}_execute`` output pin to every involved set.

    The handler asserts that all expected reads have arrived, resets the
    read counter, bumps the write-receive counter and arms the write-send
    mask.  On the iteration set it additionally calls
    ``kernel_{kernel}`` with one expression per statement argument:

    * global arguments come from device state when they are mutable reads
      or writes, otherwise from graph properties (cast to ``double*``);
    * direct dat arguments use the device's own ``dat_{dat}``;
    * indirect dat arguments use the per-argument landing buffer.

    Raises:
        RuntimeError: if a statement argument is of none of those kinds.
    """

    def argument_expression(arg_pos, arg):
        # One C++ expression for a single kernel call argument.
        if isinstance(arg, GlobalArgument):
            if (arg_pos, arg) in ctxt.mutable_global_reads | ctxt.global_writes:
                return builder.s(" deviceState->global_{id}",
                                 id=arg.global_.id)
            return builder.s(" (double*)graphProperties->global_{id}",
                             id=arg.global_.id)
        if isinstance(arg, DirectDatArgument):
            return builder.s(" deviceState->dat_{dat}", dat=arg.dat.id)
        if isinstance(arg, IndirectDatArgument):
            return builder.s(" deviceState->{invocation}_arg{index}_buffer",
                             index=arg_pos)
        raise RuntimeError("Unexpected arg type.")

    for involved_set in ctxt.get_all_involved_sets():
        with builder.subst(set=involved_set.id, kernel=ctxt.stat.name):
            handler = """
            assert( deviceState->{invocation}_in_progress );
            assert( deviceState->{invocation}_read_recv_count == {invocation}_{set}_read_recv_total );

            deviceState->{invocation}_read_recv_count=0; // Allows us to tell that execute has happened
            deviceState->{invocation}_write_recv_count++; // The increment for this "virtual" receive
            deviceState->{invocation}_write_send_mask = {invocation}_{set}_write_send_mask_all; // In device-local shared code
            """
            if involved_set == ctxt.stat.iter_set:
                expressions = [
                    argument_expression(arg_pos, arg)
                    for (arg_pos, arg) in enumerate(ctxt.stat.arguments)
                ]
                handler += "kernel_{kernel}(\n"
                # Every argument but the last is followed by a comma.
                handler += "".join(expr + ",\n" for expr in expressions[:-1])
                handler += "".join(expr + "\n" for expr in expressions[-1:])
                handler += ");\n"

            builder.add_output_pin("set_{set}", "{invocation}_execute",
                                   "executeMsgType", handler)
def create_invocation_begin(ctxt: InvocationContext, builder: GraphTypeBuilder):
    """Add the ``{invocation}_begin`` input pin to every involved set.

    The receive handler starts the invocation if it is not already in
    progress (arming the read-send mask) and counts the message as one
    read receive.  On the iteration set it also copies each mutable global
    read from the message into device state.
    """
    start_code = """
    // Standard edge trigger for start of invocation
    if(!deviceState->{invocation}_in_progress){{
        deviceState->{invocation}_in_progress=1;
        deviceState->{invocation}_read_send_mask = {invocation}_{set}_read_send_mask_all;
    }}
    deviceState->{invocation}_read_recv_count++;
    """
    for involved_set in ctxt.get_all_involved_sets():
        with builder.subst(set=involved_set.id):
            pieces = [start_code]
            if involved_set == ctxt.stat.iter_set:
                for (_, arg) in ctxt.mutable_global_reads:
                    pieces.append("""
                    copy_value(deviceState->global_{0}, message->global_{0});
                    """.format(arg.global_.id))
            builder.add_input_pin("set_{set}", "{invocation}_begin",
                                  "{invocation}_begin", None, None,
                                  "".join(pieces))
def create_rts(ctxt: InvocationContext, builder: GraphTypeBuilder):
    """Add the invocation's ready-to-send clause to every involved set.

    While the invocation is in progress the clause ORs both pending send
    masks into ``*readyToSend``, raises the execute flag once the read
    receive count reaches its total, and raises the end flag once the
    read count is back at zero, all writes have been received and both
    send masks are empty.
    """
    rts_clause = """
    if(deviceState->{invocation}_in_progress) {{
        *readyToSend |= deviceState->{invocation}_read_send_mask;
        *readyToSend |= deviceState->{invocation}_write_send_mask;
        if(deviceState->{invocation}_read_recv_count == {invocation}_{set}_read_recv_total){{
            *readyToSend |= RTS_FLAG_{invocation}_execute;
        }}
        if(deviceState->{invocation}_read_recv_count==0
            && deviceState->{invocation}_write_recv_count==deviceProperties->{invocation}_write_recv_total
            && deviceState->{invocation}_read_send_mask==0
            && deviceState->{invocation}_write_send_mask==0
        ){{
            *readyToSend |= RTS_FLAG_{invocation}_end;
        }}
    }} // if(deviceState->{invocation}_in_progress)
    """
    for involved_set in ctxt.get_all_involved_sets():
        with builder.subst(set=involved_set.id):
            builder.add_rts_clause("set_{set}", rts_clause)
def create_read_recv_counts(ctxt: InvocationContext,
                            builder: GraphTypeBuilder):
    """Emit ``{invocation}_{set}_read_recv_total`` for every involved set.

    Each set's total starts at 1 for the ``{invocation}_begin`` message.
    Every indirect read then adds to the total of its ``iter_set``:
    ``-arg.index`` when the index is negative (several values arrive on
    that pin), otherwise 1.  The total is added as device shared code to
    the set's device type.
    """
    # Start off at 1, as that represents {invocation}_begin
    totals = {s: 1 for s in ctxt.get_all_involved_sets()}
    for (_, arg) in ctxt.indirect_reads:
        # A negative index means multiple values arrive on this pin.
        totals[arg.iter_set] += -arg.index if arg.index < 0 else 1

    for (set_, total) in totals.items():
        with builder.subst(set=set_.id, read_recv_count=total):
            builder.add_device_shared_code(
                "set_{set}", """
                const uint32_t {invocation}_{set}_read_recv_total = {read_recv_count};
                """)
def create_all_messages(ctxt: InvocationContext, builder: GraphTypeBuilder):
    """Declare every message type the invocation uses.

    * ``{invocation}_begin`` carries one ``global_<id>`` field per mutable
      global read.
    * ``{invocation}_end`` carries one ``global_<id>`` field per global
      write.
    * ``dat_{dat}`` (a single ``value`` of the dat's type) is merged in
      for every indirectly accessed dat.
    * ``dat_{dat}_x{arity}`` is merged in as well for indirect writes with
      a negative index; its ``value`` has a leading dimension of
      ``-arg.index``.
    """
    # If a global appears twice the key is simply rewritten with the same type.
    begin_fields = {
        "global_{}".format(arg.global_.id): arg.data_type
        for (_, arg) in ctxt.mutable_global_reads
    }
    builder.create_message_type("{invocation}_begin", begin_fields)

    end_fields = {
        "global_{}".format(arg.global_.id): arg.data_type
        for (_, arg) in ctxt.global_writes
    }
    builder.create_message_type("{invocation}_end", end_fields)

    for (arg_pos, arg) in ctxt.indirect_reads | ctxt.indirect_writes:
        dat = arg.dat
        with builder.subst(dat=dat.id, arity=-arg.index):
            builder.merge_message_type("dat_{dat}", {"value": dat.data_type})
            needs_vector = (arg.index < 0
                            and (arg_pos, arg) in ctxt.indirect_writes)
            if needs_vector:
                vector_type = DataType(
                    dat.data_type.dtype,
                    (-arg.index, ) + dat.data_type.shape)
                builder.merge_message_type("dat_{dat}_x{arity}",
                                           {"value": vector_type})
def create_state_tracking_variables(ctxt: InvocationContext,
                                    builder: GraphTypeBuilder):
    """Add the per-invocation bookkeeping members to every involved set.

    Each set's device type gets five ``uint32`` state members
    (``in_progress``, the read/write send masks and the read/write receive
    counters, all prefixed with ``{invocation}_``) and one ``uint32``
    device property, ``{invocation}_write_recv_total``.
    """
    state_members = (
        "in_progress",
        "read_send_mask",
        "read_recv_count",
        "write_send_mask",
        "write_recv_count",
    )
    for involved_set in ctxt.get_all_involved_sets():
        with builder.subst(set=involved_set.id):
            for member in state_members:
                builder.add_device_state("set_{set}",
                                         "{invocation}_" + member,
                                         scalar_uint32)
            builder.add_device_property("set_{set}",
                                        "{invocation}_write_recv_total",
                                        scalar_uint32)
def create_indirect_write_sends(ctxt: InvocationContext,
                                builder: GraphTypeBuilder):
    """Add one ``{invocation}_arg{index}_write_send`` output pin per indirect write.

    The pin is added to the device type of the argument's ``iter_set``.
    It sends a ``dat_{dat}`` message for a non-negative ``arg.index`` and
    a ``dat_{dat}_x{arity}`` message otherwise.  In both cases the handler
    asserts the pin's flag is set in the write-send mask, copies the
    argument's landing buffer into the message, and clears the flag.
    """
    send_handler = """
    assert(deviceState->{invocation}_write_send_mask & RTS_FLAG_{invocation}_arg{index}_write_send);
    copy_value(message->value, deviceState->{invocation}_arg{index}_buffer);
    deviceState->{invocation}_write_send_mask &= ~RTS_FLAG_{invocation}_arg{index}_write_send;
    """
    for (arg_pos, arg) in ctxt.indirect_writes:
        with builder.subst(index=arg_pos, set=arg.iter_set.id, dat=arg.dat.id,
                           arity=-arg.index):
            # Only the message type depends on the sign of the index.
            message_type = ("dat_{dat}"
                            if arg.index >= 0 else "dat_{dat}_x{arity}")
            builder.add_output_pin("set_{set}",
                                   "{invocation}_arg{index}_write_send",
                                   message_type, send_handler)
def create_write_send_masks(ctxt: InvocationContext,
                            builder: GraphTypeBuilder):
    """Emit ``{invocation}_{set}_write_send_mask_all`` for every involved set.

    The constant is the bitwise OR of ``0`` and the
    ``RTS_FLAG_{invocation}_arg{index}_write_send`` flag of each indirect
    write whose ``iter_set`` is that set.  It is added as device shared
    code to the set's device type, and each contributing argument is
    logged at INFO level.

    The arguments are collected in a Python ``set``, whose iteration
    order is arbitrary, so they are sorted by argument index before use.
    That keeps both the log output and the generated text identical from
    run to run; the value of the mask does not depend on the order.
    """
    write_sends = {s: set() for s in ctxt.get_all_involved_sets()}
    for (ai, arg) in ctxt.indirect_writes:
        write_sends[arg.iter_set].add((ai, arg))

    for (s, args) in write_sends.items():
        ordered_args = sorted(args, key=lambda pair: pair[0])
        for (ai, arg) in ordered_args:
            logging.info(
                "write send mask set=%s, arg=%s, to_set=%s, iter_set=%s",
                s.id, arg, arg.to_set.id, arg.iter_set.id)
        flags = ["0"] + [
            builder.s("RTS_FLAG_{invocation}_arg{index}_write_send", index=ai)
            for (ai, _) in ordered_args
        ]
        write_send_set = "(" + "|".join(flags) + ")"
        with builder.subst(set=s.id, write_send_set=write_send_set):
            builder.add_device_shared_code(
                "set_{set}", """
                const uint32_t {invocation}_{set}_write_send_mask_all = {write_send_set};
                """)
def create_global_landing_pads(ctxt: InvocationContext,
                               builder: GraphTypeBuilder):
    """Merge a ``global_{name}`` state member into the iteration set's device type.

    One member is merged for every global that the invocation reads
    mutably or writes, using the argument's data type.
    """
    touched_globals = ctxt.mutable_global_reads | ctxt.global_writes
    for (_, arg) in touched_globals:
        with builder.subst(name=arg.global_.id, set=ctxt.stat.iter_set.id):
            builder.merge_device_state("set_{set}", "global_{name}",
                                       arg.data_type)
def create_invocation_tester(testIndex: int, isLast: bool,
                             ctxt: InvocationContext,
                             builder: GraphTypeBuilder):
    """Add this invocation's begin/end pins and RTS clause to the ``tester`` device.

    The tester steps through ``test_state``: the even value
    ``2*testIndex`` makes it ready to send ``{invocation}_begin``, and the
    following odd value means it is waiting for ``{invocation}_end``
    messages.  Each end message increments ``end_received`` and
    accumulates any written globals; once
    ``{invocation}_total_responding_devices`` have replied the tester
    either exits (when ``isLast``) or advances to the next invocation.

    Args:
        testIndex: Position of this invocation in the test sequence.
        isLast: Whether this is the final invocation to test.
        ctxt: Context describing the invocation.
        builder: Graph type builder that receives the pins and clause.
    """
    with builder.subst(testIndex=testIndex, isLast=int(isLast)):
        begin_handler = """
        assert(deviceState->end_received==0);
        assert(deviceState->test_state==2*{testIndex});
        deviceState->test_state++; // Odd value means we are waiting for the return
        deviceState->end_received=0;
        """
        # NOTE: copying each mutable global read from
        # graphProperties->test_{invocation}_{name}_in into the begin
        # message is currently disabled.
        builder.add_output_pin("tester", "{invocation}_begin",
                               "{invocation}_begin", begin_handler)

        end_parts = ["""
        assert(deviceState->test_state==2*{testIndex}+1);
        assert(deviceState->end_received < graphProperties->{invocation}_total_responding_devices);
        deviceState->end_received++;
        """]

        # Collect any inc's to global values
        for (_, arg) in ctxt.global_writes:
            assert arg.access_mode == AccessMode.INC
            end_parts.append(builder.s("""
            inc_value(deviceState->global_{name}, message->global_{name});
            """, name=arg.global_.id))

        # Check whether we have finished
        end_parts.append("""
        if(deviceState->end_received == graphProperties->{invocation}_total_responding_devices){{
        """)

        # Removed for now - not clear how to do this.
        check_results = False
        if check_results:
            # ... and if so, try to check the results are "right"
            for (_, arg) in ctxt.global_writes:
                end_parts.append(builder.s("""
                check_value(deviceState->global_{name}, graphProperties->test_{invocation}_{name}_out);
                """, name=arg.global_.id))

        end_parts.append("""
            if( {isLast} ){{
                handler_exit(0);
            }}else{{
                deviceState->test_state++; // start the next invocation
                deviceState->end_received=0;
            }}
        }}
        """)
        builder.add_input_pin("tester", "{invocation}_end",
                              "{invocation}_end", None, None,
                              "".join(end_parts))

        builder.add_rts_clause(
            "tester", """
            if(deviceState->test_state==2*{testIndex}){{
                *readyToSend = RTS_FLAG_{invocation}_begin;
            }}
            """)
def compile_global_controller(gi: str, spec: SystemSpecification,
                              builder: GraphTypeBuilder, code: Statement):
    """Generate the state machine that drives ``code`` on device type ``gi``.

    The device gets ``rts`` and ``state`` members, an RTS clause that
    exposes ``rts`` directly, an ``__init__`` handler that enters the
    statement's start state and initialises every mutable global from
    ``graphProperties->init_global_*``, and a ``control`` output pin whose
    handler switches on the current state.  The C translations of the
    scalar user code and of every ``While`` condition found in ``code``
    are added as shared code on the same device type.

    Args:
        gi: Identifier of the device type that hosts the controller.
        spec: System specification; must contain the ``_cond_`` global.
        builder: Graph type builder that receives the generated pieces.
        code: Root statement of the program to control.
    """
    create_controller_states(
        code)  # Make sure every statement has an entry state

    builder.add_device_state(gi, "rts", scalar_uint32)
    builder.add_device_state(gi, "state", scalar_uint32)

    # This will be used as a hidden global, and captures the value
    # of the condition for If and While
    assert "_cond_" in spec.globals

    start_state = get_statement_state(code)
    finish_state = make_state()

    builder.add_rts_clause(gi, "*readyToSend = deviceState->rts;\n")

    with builder.subst(start_state=start_state):
        handler = """
        deviceState->rts=RTS_FLAG_control;
        deviceState->state={start_state};
        handler_log(4, "rts=%x, state=%d", deviceState->rts, deviceState->state);
        """
        for mg in spec.globals.values():
            if isinstance(mg, MutableGlobal):
                # Only the appended piece is formatted here; {start_state}
                # above is left for the builder to substitute.
                handler += """
                copy_value(deviceState->global_{global_}, graphProperties->init_global_{global_});
                """.format(global_=mg.id)
        builder.add_input_pin(gi, "__init__", "__init__", None, None, handler)

    handler = raw("""
    handler_log(4, "rts=%x, state=%d", deviceState->rts, deviceState->state);
    *doSend=0; // disable this send...
    deviceState->rts=RTS_FLAG_control; // ... but say that we want to send again (by default)
    switch(deviceState->state){
    """)
    handler += render_controller_statement(code, finish_state)
    handler += """
    case {finish_state}:
        handler_log(4, "Hit finish state.");
        handler_exit(0);
        break;
    default:
        handler_log(3, "Unknown state id %d.", deviceState->state);
        assert(0);
    }}
    """.format(finish_state=finish_state)

    builder.create_message_type("control", {})
    builder.add_output_pin(gi, "control", "control", handler)

    # Shared code goes on the device type named by gi, like everything
    # else in this function, rather than on a hard-coded "controller".
    for user_code in find_scalars_in_code(code):
        name = user_code.id
        assert user_code.ast
        src = mini_op2.framework.kernel_translator.scalar_to_c(
            user_code.ast, name)
        builder.add_device_shared_code(gi, raw(src))

    for k in code.all_statements():
        if isinstance(k, While):
            name = k.id
            src = mini_op2.framework.kernel_translator.scalar_to_c(
                k.expr_ast, name)
            builder.add_device_shared_code(gi, raw(src))
def sync_compiler(spec: SystemSpecification, code: Statement):
    """Build the ``op2_inst`` graph type for ``spec`` running ``code``.

    The builder is populated with the shared C++ helpers, the
    ``controller`` and ``tester`` device types, one ``set_<id>`` device
    type per set (with its dats and an ``__init__`` handler that loads
    them), the pins for every kernel invocation found in ``code``, and
    finally the global controller state machine.

    Args:
        spec: System specification supplying the globals and sets.
        code: Root statement of the program to compile.

    Returns:
        The populated ``GraphTypeBuilder``.

    Raises:
        RuntimeError: if a global is neither a ``MutableGlobal`` nor a
            ``ConstGlobal``.
    """
    builder = GraphTypeBuilder("op2_inst")

    builder.add_shared_code_raw(r"""
    #include <cmath>
    #include <cstdio>
    #include <cstdarg>

    void fprintf_stderr(const char *msg, ...)
    {
        va_list v;
        va_start(v,msg);
        vfprintf(stderr, msg, v);
        fprintf(stderr, "\n");
        va_end(v);
    }

    template<class T,unsigned N>
    void copy_value(T (&x)[N], const T (&y)[N]){
        for(unsigned i=0; i<N; i++){
            x[i]=y[i];
        }
    }

    template<class T,unsigned N,unsigned M>
    void copy_value(T (&x)[N][M], const T (&y)[N][M]){
        for(unsigned i=0; i<N; i++){
            for(unsigned j=0; j<M; j++){
                x[i][j]=y[i][j];
            }
        }
    }

    /*template<class T>
    void copy_value(T &x, const T (&y)[1]){
        x[0]=y[0];
    }*/

    template<class T,unsigned N>
    void inc_value(T (&x)[N], const T (&y)[N]){
        for(unsigned i=0; i<N; i++){
            x[i]+=y[i];
        }
    }

    template<class T,unsigned N>
    void zero_value(T (&x)[N]){
        for(unsigned i=0; i<N; i++){
            x[i]=0;
        }
    }

    /* Mainly for debug. Used to roughly check that a calculated value is correct, based
       on "known-good" pre-calculated values. Leaves a lot to be desired... */
    template<class T,unsigned N>
    void check_value(T (&got)[N], T (&ref)[N] ){
        for(unsigned i=0; i<N; i++){
            auto diff=std::abs( got[i] - ref[i] );
            assert( diff < 1e-6 ); // Bleh...
        }
    }
    """)

    builder.create_message_type("executeMsgType", {})

    # Support two kinds of global. Only one can be wired into an instance.
    builder.create_device_type(
        "controller")  # This runs the actual program logic
    builder.create_device_type(
        "tester")  # This solely tests each invocation in turn
    builder.add_device_state("tester", "test_state",
                             DataType(shape=(), dtype=numpy.uint32))
    builder.add_device_state("tester", "end_received",
                             DataType(shape=(), dtype=numpy.uint32))
    builder.add_device_state("controller", "end_received",
                             DataType(shape=(), dtype=numpy.uint32))
    builder.add_device_state("controller", "invocation_index",
                             DataType(shape=(), dtype=numpy.uint32))

    for global_ in spec.globals.values():
        if isinstance(global_, MutableGlobal):
            # Mutable globals live in device state and are seeded from a
            # graph property.
            builder.add_device_state("controller",
                                     "global_{}".format(global_.id),
                                     global_.data_type)
            builder.add_device_state("tester",
                                     "global_{}".format(global_.id),
                                     global_.data_type)
            builder.add_graph_property("init_global_{}".format(global_.id),
                                       global_.data_type)
        elif isinstance(global_, ConstGlobal):
            builder.add_graph_property("global_{}".format(global_.id),
                                       global_.data_type)
        else:
            # The type has to be formatted into the message; passing it
            # as a second argument left the "{}" unfilled.
            raise RuntimeError("Unexpected global type : {}".format(
                type(global_)))

    builder.merge_message_type("__init__", {})

    for s in spec.sets.values():
        with builder.subst(set="set_" + s.id):
            builder.create_device_type("{set}")
            init_handler = ""
            for dat in s.dats.values():
                with builder.subst(dat=dat.id):
                    builder.add_device_property("{set}", "init_dat_{dat}",
                                                dat.data_type)
                    builder.add_device_state("{set}", "dat_{dat}",
                                             dat.data_type)
                    init_handler += builder.s(
                        " copy_value(deviceState->dat_{dat}, deviceProperties->init_dat_{dat});\n"
                    )
            builder.add_input_pin("{set}", "__init__", "__init__", None, None,
                                  init_handler)

    kernels = find_kernels_in_code(code)

    emitted_kernels = set()
    for (i, stat) in enumerate(kernels):
        ctxt = InvocationContext(spec, stat)
        with builder.subst(invocation=ctxt.invocation):
            compile_invocation(spec, builder, ctxt, emitted_kernels)
            create_invocation_tester(i, i + 1 == len(kernels), ctxt, builder)
            create_invocation_controller(ctxt, builder)

    compile_global_controller("controller", spec, builder, code)

    return builder