def set_rnn_bindings(gen, step_idx, in_eparams, out_eparams, cname, rnn_params, rnn_q):
    """Append the comment and kernel argument bindings for an RNN node.

    Args:
        gen: Generator whose ``bindings`` and ``locals`` lists are appended to.
        step_idx: Step index used to name the intermediate state buffers
            (``S<step_idx>_StateInternal0N``).
        in_eparams: Input edges of the node, indexed through
            ``rnn_params.get_name_indexes()``.
        out_eparams: Output edges of the node; only the first one is bound.
        cname: Name handed to the comment and node binding lists.
        rnn_params: Node parameters; provides the input name indexes and is
            passed to ``num_sequences``.
        rnn_q: Quantization record providing ``in_qs`` and ``out_qs``.
    """
    names = rnn_params.get_name_indexes()
    gen.bindings.append(
        CommentBindingList(
            "Node {} inq {} weightsq {} outq {} biasesq {}",
            cname,
            rnn_q.in_qs[0],
            rnn_q.in_qs[names['r_2_i_w']],
            rnn_q.out_qs[0],
            rnn_q.in_qs[names['i_b']]))

    state_internal_1 = "S%s_StateInternal01" % step_idx
    state_internal_2 = "S%s_StateInternal02" % step_idx

    # One intermediate state buffer is declared per extra sequence
    # (at most two).
    num_seq = num_sequences(rnn_params)
    if num_seq > 1:
        gen.locals.append(LocalArgInfo("int8", state_internal_1))
    if num_seq > 2:
        gen.locals.append(LocalArgInfo("int8", state_internal_2))

    i_state_eparams = in_eparams[names['i_state']]
    gen.bindings.append(
        NodeBindingList(
            cname,
            GNodeArgEdge(i_state_eparams, direction="GNA_INOUT"),
            GNodeArgEdge(state_internal_1,
                         alias=i_state_eparams,
                         direction="GNA_INOUT") if num_seq > 1 else NoArg(),
            # The second intermediate state buffer aliases the first one.
            # This used to name S<step>_CellInternal01, a buffer that this
            # function never declares.
            GNodeArgEdge(state_internal_2,
                         alias=state_internal_1,
                         direction="GNA_INOUT") if num_seq > 2 else NoArg(),
            GNodeArgEdge(in_eparams[0]),
            GNodeArgEdge(in_eparams[names['r_2_i_w']]),
            GNodeArgEdge(in_eparams[names['i_b']]),
            GNodeArgEdge(out_eparams[0], direction="GNA_OUT"),
            GNodeArgNode(rnn_params, INFOS),
            GArgName(i_state_eparams.creating_node.reset_name)))
def bindings_generator(cls, gen, node, qrec, in_eparams, out_eparams, cname) -> bool:
    """Emit the comment and kernel argument bindings for a gated recurrent node.

    The node must expose an ``h_state`` input plus, for each of the ``r``,
    ``z`` and ``h`` gates, recurrent (``r_2_<gate>_w``) and input
    (``w_2_<gate>_w``) weights and the matching biases. This is presumably
    a GRU node; confirm against the class this generator belongs to.

    Args:
        gen: Generator whose ``bindings`` and ``locals`` lists are appended to.
        node: Node being generated; provides name indexes, ``step_idx`` and
            ``rnn_states_as_inputs``.
        qrec: Quantization record providing ``in_qs`` and ``out_qs``.
        in_eparams: Input edges of the node.
        out_eparams: Output edges of the node; only the first one is bound.
        cname: Name handed to the comment and node binding lists.

    Returns:
        Always ``True``.
    """
    names = node.get_name_indexes()
    gen.bindings.append(
        CommentBindingList("Node {} inq {} outq {}", cname,
                           qrec.in_qs[0], qrec.out_qs[0]))

    num_seq = num_sequences(node)
    step_idx = node.step_idx
    # The intermediate state buffers are unsigned and match the input width.
    in_ctype = "short" if qrec.in_qs[0].dtype_bits != 8 else "char"
    state_buf_1 = f"S{step_idx}_StateInternal01"
    state_buf_2 = f"S{step_idx}_StateInternal02"

    for min_seq, buf_name in ((1, state_buf_1), (2, state_buf_2)):
        if num_seq > min_seq:
            gen.locals.append(LocalArgInfo(f"unsigned {in_ctype}", buf_name))

    i_state_eparams = in_eparams[names['h_state']]
    # NOTE(review): set_lstm_bindings / set_gru_bindings in this file test
    # `not rnn_states_as_inputs` here -- confirm which polarity is intended.
    if node.rnn_states_as_inputs:
        reset_name = i_state_eparams.creating_node.reset_name
    else:
        reset_name = "Reset"

    def named_edge(key):
        return GNodeArgEdge(in_eparams[names[key]])

    # Bias inputs bound after each gate's weights; the 'h' gate has two.
    bias_keys = {
        'r': ('r_b',),
        'z': ('z_b',),
        'h': ('w_h_b', 'r_h_b'),
    }

    bindings = [GNodeArgEdge(i_state_eparams, direction="GNA_INOUT")]
    if num_seq > 1:
        bindings.append(GNodeArgEdge(state_buf_1,
                                     alias=i_state_eparams,
                                     direction="GNA_INOUT"))
    else:
        bindings.append(NoArg())
    if num_seq > 2:
        bindings.append(GNodeArgEdge(state_buf_2,
                                     alias=state_buf_1,
                                     direction="GNA_INOUT"))
    else:
        bindings.append(NoArg())
    bindings.append(GNodeArgEdge(in_eparams[0]))
    bindings.append(GNodeArgNode(node, "scalenorm"))

    for gate in ('r', 'z', 'h'):
        # Recurrent weights first, then input weights, then the bias(es).
        bindings += [named_edge(f'{inp_t}_2_{gate}_w') for inp_t in ('r', 'w')]
        bindings += [named_edge(bias_key) for bias_key in bias_keys[gate]]

    bindings.append(GNodeArgEdge(out_eparams[0], direction="GNA_OUT"))
    bindings.append(GNodeArgNode(node, INFOS))
    bindings.append(GArgName(reset_name))

    gen.bindings.append(NodeBindingList(cname, *bindings))
    return True
def set_lstm_bindings(gen, step_idx, in_eparams, out_eparams, cname, rnn_params, rnn_q):
    """Append the comment and kernel argument bindings for an LSTM node.

    Args:
        gen: Generator whose ``bindings`` and ``locals`` lists are appended to.
        step_idx: Step index used to name the intermediate cell/state buffers
            (``S<step_idx>_CellInternal0N`` / ``S<step_idx>_StateInternal0N``).
        in_eparams: Input edges of the node, indexed through
            ``rnn_params.get_name_indexes()``.
        out_eparams: Output edges of the node; only the first one is bound.
        cname: Name handed to the comment and node binding lists.
        rnn_params: Node parameters; provides the input name indexes,
            ``rnn_states_as_inputs`` and is passed to ``num_sequences``.
        rnn_q: Quantization record providing ``in_qs`` and ``out_qs``.
    """
    names = rnn_params.get_name_indexes()
    gen.bindings.append(
        CommentBindingList("Node {} inq {} outq {}", cname,
                           rnn_q.in_qs[0], rnn_q.out_qs[0]))

    c_state_eparams = in_eparams[names['c_state']]
    i_state_eparams = in_eparams[names['i_state']]

    cell_internal_1 = "S%s_CellInternal01" % step_idx
    state_internal_1 = "S%s_StateInternal01" % step_idx
    cell_internal_2 = "S%s_CellInternal02" % step_idx
    state_internal_2 = "S%s_StateInternal02" % step_idx

    # One cell/state pair of intermediate buffers is declared per extra
    # sequence (at most two pairs).
    num_seq = num_sequences(rnn_params)
    if num_seq > 1:
        gen.locals.append(LocalArgInfo("int8", cell_internal_1))
        gen.locals.append(LocalArgInfo("int8", state_internal_1))
    if num_seq > 2:
        gen.locals.append(LocalArgInfo("int8", cell_internal_2))
        gen.locals.append(LocalArgInfo("int8", state_internal_2))

    # NOTE(review): the bindings_generator functions in this file use the
    # opposite polarity (`if node.rnn_states_as_inputs`) -- confirm which
    # one is intended.
    if not rnn_params.rnn_states_as_inputs:
        reset_name = i_state_eparams.creating_node.reset_name
    else:
        reset_name = "Reset"

    gen.bindings.append(
        NodeBindingList(
            cname,
            GNodeArgEdge(c_state_eparams, direction="GNA_INOUT"),
            GNodeArgEdge(i_state_eparams, direction="GNA_INOUT"),
            GNodeArgEdge(cell_internal_1,
                         alias=c_state_eparams,
                         direction="GNA_INOUT") if num_seq > 1 else NoArg(),
            GNodeArgEdge(state_internal_1,
                         alias=i_state_eparams,
                         direction="GNA_INOUT") if num_seq > 1 else NoArg(),
            GNodeArgEdge(cell_internal_2,
                         alias=cell_internal_1,
                         direction="GNA_INOUT") if num_seq > 2 else NoArg(),
            # The second state buffer aliases the first *state* buffer,
            # mirroring the cell pairing above. It used to alias
            # S<step>_CellInternal01, the same target as CellInternal02.
            GNodeArgEdge(state_internal_2,
                         alias=state_internal_1,
                         direction="GNA_INOUT") if num_seq > 2 else NoArg(),
            GNodeArgEdge(in_eparams[0]),
            GNodeArgEdge(in_eparams[names['r_2_f_w']]),
            GNodeArgEdge(in_eparams[names['f_b']]),
            GNodeArgEdge(in_eparams[names['r_2_i_w']]),
            GNodeArgEdge(in_eparams[names['i_b']]),
            GNodeArgEdge(in_eparams[names['r_2_c_w']]),
            GNodeArgEdge(in_eparams[names['c_b']]),
            GNodeArgEdge(in_eparams[names['r_2_o_w']]),
            GNodeArgEdge(in_eparams[names['o_b']]),
            GNodeArgEdge(out_eparams[0], direction="GNA_OUT"),
            GNodeArgNode(rnn_params, INFOS),
            GArgName(reset_name)))
def bindings_generator(cls, gen, node, qrec, in_eparams, out_eparams, cname) -> bool:
    """Emit the comment and kernel argument bindings for a recurrent node.

    The node must expose an ``i_state`` input, recurrent weights
    (``r_2_i_w``), a bias (``i_b``) and, unless ``node.rnn_same_inout_scale``
    is set, input weights (``i_2_i_w``). This is presumably a plain RNN
    node; confirm against the class this generator belongs to.

    Args:
        gen: Generator whose ``bindings`` and ``locals`` lists are appended to.
        node: Node being generated; provides name indexes, ``step_idx``,
            ``rnn_states_as_inputs`` and ``rnn_same_inout_scale``.
        qrec: Quantization record providing ``in_qs`` and ``out_qs``.
        in_eparams: Input edges of the node.
        out_eparams: Output edges of the node; only the first one is bound.
        cname: Name handed to the comment and node binding lists.

    Returns:
        Always ``True``.
    """
    names = node.get_name_indexes()

    in_q = qrec.in_qs[0]
    weights_q = qrec.in_qs[names['r_2_i_w']]
    out_q = qrec.out_qs[0]
    biases_q = qrec.in_qs[names['i_b']]
    gen.bindings.append(
        CommentBindingList("Node {} inq {} weightsq {} outq {} biasesq {}",
                           cname, in_q, weights_q, out_q, biases_q))

    def state_buffer(level):
        # Name of the level-th intermediate state buffer for this step.
        return f"S{node.step_idx}_StateInternal{level:02d}"

    num_seq = num_sequences(node)
    for level in (1, 2):
        if num_seq > level:
            gen.locals.append(LocalArgInfo("uint8", state_buffer(level)))

    i_state_eparams = in_eparams[names['i_state']]
    # NOTE(review): set_lstm_bindings / set_gru_bindings in this file test
    # `not rnn_states_as_inputs` here -- confirm which polarity is intended.
    if node.rnn_states_as_inputs:
        reset_name = i_state_eparams.creating_node.reset_name
    else:
        reset_name = "Reset"

    args = [GNodeArgEdge(i_state_eparams, direction="GNA_INOUT")]

    # Intermediate state buffers: the first aliases the state input, the
    # second aliases the first. Unused slots are filled with NoArg().
    if num_seq > 1:
        args.append(GNodeArgEdge(state_buffer(1),
                                 alias=i_state_eparams,
                                 direction="GNA_INOUT"))
    else:
        args.append(NoArg())
    if num_seq > 2:
        args.append(GNodeArgEdge(state_buffer(2),
                                 alias=state_buffer(1),
                                 direction="GNA_INOUT"))
    else:
        args.append(NoArg())

    args.append(GNodeArgEdge(in_eparams[0]))
    args.append(GNodeArgNode(node, "scalenorm"))
    args.append(GNodeArgEdge(in_eparams[names['r_2_i_w']]))

    # Separate input weights are only bound when input and output scales differ.
    if node.rnn_same_inout_scale:
        args.append(NoArg())
    else:
        args.append(GNodeArgEdge(in_eparams[names['i_2_i_w']]))

    args.append(GNodeArgEdge(in_eparams[names['i_b']]))
    args.append(GNodeArgEdge(out_eparams[0], direction="GNA_OUT"))
    args.append(GNodeArgNode(node, INFOS))
    args.append(GArgName(reset_name))

    gen.bindings.append(NodeBindingList(cname, *args))
    return True
def set_gru_bindings(gen, step_idx, in_eparams, out_eparams, cname, rnn_params, rnn_q):
    """Append the comment and kernel argument bindings for a GRU node.

    Args:
        gen: Generator whose ``bindings`` and ``locals`` lists are appended to.
        step_idx: Step index used to name the intermediate state buffers
            (``S<step_idx>_StateInternal0N``).
        in_eparams: Input edges of the node, indexed through
            ``rnn_params.get_name_indexes()``.
        out_eparams: Output edges of the node; only the first one is bound.
        cname: Name handed to the comment and node binding lists.
        rnn_params: Node parameters; provides the input name indexes,
            ``name``, ``rnn_states_as_inputs`` and is passed to
            ``num_sequences``.
        rnn_q: Quantization record providing ``in_qs`` and ``out_qs``.
    """
    names = rnn_params.get_name_indexes()
    gen.bindings.append(
        CommentBindingList("Node {} inq {} outq {}", cname,
                           rnn_q.in_qs[0], rnn_q.out_qs[0]))

    state_internal_1 = "S%s_StateInternal01" % step_idx
    state_internal_2 = "S%s_StateInternal02" % step_idx

    # One intermediate state buffer is declared per extra sequence
    # (at most two).
    num_seq = num_sequences(rnn_params)
    if num_seq > 1:
        gen.locals.append(LocalArgInfo("int8", state_internal_1))
    if num_seq > 2:
        gen.locals.append(LocalArgInfo("int8", state_internal_2))

    i_state_eparams = in_eparams[names['h_state']]
    # NOTE(review): the bindings_generator functions in this file use the
    # opposite polarity (`if node.rnn_states_as_inputs`) -- confirm which
    # one is intended.
    if not rnn_params.rnn_states_as_inputs:
        reset_name = i_state_eparams.creating_node.reset_name
    else:
        reset_name = f"{rnn_params.name}_Reset"

    gen.bindings.append(
        NodeBindingList(
            cname,
            GNodeArgEdge(i_state_eparams, direction="GNA_INOUT"),
            GNodeArgEdge(state_internal_1,
                         alias=i_state_eparams,
                         direction="GNA_INOUT") if num_seq > 1 else NoArg(),
            # The second intermediate state buffer aliases the first one.
            # This used to name S<step>_CellInternal01, a buffer that this
            # function never declares.
            GNodeArgEdge(state_internal_2,
                         alias=state_internal_1,
                         direction="GNA_INOUT") if num_seq > 2 else NoArg(),
            GNodeArgEdge(in_eparams[0]),
            GNodeArgEdge(in_eparams[names['r_2_r_w']]),
            GNodeArgEdge(in_eparams[names['r_b']]),
            GNodeArgEdge(in_eparams[names['r_2_z_w']]),
            GNodeArgEdge(in_eparams[names['z_b']]),
            GNodeArgEdge(in_eparams[names['r_2_h_w']]),
            GNodeArgEdge(in_eparams[names['w_h_b']]),
            GNodeArgEdge(in_eparams[names['r_h_b']]),
            GNodeArgEdge(out_eparams[0], direction="GNA_OUT"),
            # NOTE(review): unlike the other bindings in this file, no
            # GNodeArgNode(rnn_params, INFOS) is passed here -- confirm
            # against the GRU kernel's argument list.
            GArgName(reset_name)))