Ejemplo n.º 1
0
def set_rnn_bindings(gen, step_idx, in_eparams, out_eparams, cname, rnn_params,
                     rnn_q):
    """Append the comment, local buffers and call bindings for an RNN node.

    Appends one ``CommentBindingList`` and one ``NodeBindingList`` to
    ``gen.bindings``. Depending on ``num_sequences(rnn_params)``, it also
    appends zero to two int8 ``S<step_idx>_StateInternal0N`` locals to
    ``gen.locals``.

    Args:
        gen: code generator whose ``bindings`` and ``locals`` lists are
            appended to.
        step_idx: step index used to build the internal local names.
        in_eparams: input edge parameters, indexed via
            ``rnn_params.get_name_indexes()``.
        out_eparams: output edge parameters; only index 0 is bound.
        cname: name of the generated kernel call.
        rnn_params: RNN node parameters.
        rnn_q: quantization record; its ``in_qs``/``out_qs`` are printed in
            the comment binding.
    """
    names = rnn_params.get_name_indexes()

    gen.bindings.append(
        CommentBindingList("Node {} inq {} weightsq {} outq {} biasesq {}",
                           cname, rnn_q.in_qs[0],
                           rnn_q.in_qs[names['r_2_i_w']], rnn_q.out_qs[0],
                           rnn_q.in_qs[names['i_b']]))
    num_seq = num_sequences(rnn_params)
    # One extra int8 local per additional sequence, up to two.
    if num_seq > 1:
        gen.locals.append(
            LocalArgInfo("int8", "S%s_StateInternal01" % step_idx))

    if num_seq > 2:
        gen.locals.append(
            LocalArgInfo("int8", "S%s_StateInternal02" % step_idx))

    i_state_eparams = in_eparams[names['i_state']]
    gen.bindings.append(
        NodeBindingList(
            cname, GNodeArgEdge(i_state_eparams, direction="GNA_INOUT"),
            GNodeArgEdge("S%s_StateInternal01" % step_idx,
                         alias=i_state_eparams,
                         direction="GNA_INOUT") if num_seq > 1 else NoArg(),
            # The alias must name a local declared above: an RNN only has
            # StateInternal buffers (there is no CellInternal local here).
            GNodeArgEdge("S%s_StateInternal02" % step_idx,
                         alias="S%s_StateInternal01" % step_idx,
                         direction="GNA_INOUT") if num_seq > 2 else NoArg(),
            GNodeArgEdge(in_eparams[0]),
            GNodeArgEdge(in_eparams[names['r_2_i_w']]),
            GNodeArgEdge(in_eparams[names['i_b']]),
            GNodeArgEdge(out_eparams[0], direction="GNA_OUT"),
            GNodeArgNode(rnn_params, INFOS),
            GArgName(i_state_eparams.creating_node.reset_name)))
Ejemplo n.º 2
0
    def bindings_generator(cls, gen, node, qrec, in_eparams, out_eparams,
                           cname) -> bool:
        """Generate the bindings for a GRU-style node.

        Appends, in order: one ``CommentBindingList`` to ``gen.bindings``;
        zero to two ``S<step_idx>_StateInternal0N`` locals to ``gen.locals``
        (depending on ``num_sequences(node)``); and one ``NodeBindingList``
        with the kernel call arguments to ``gen.bindings``.

        Args:
            gen: code generator collecting ``bindings`` and ``locals``.
            node: node parameters; supplies the name indexes, ``step_idx``
                and ``rnn_states_as_inputs``.
            qrec: quantization record; ``in_qs[0]`` and ``out_qs[0]`` are used.
            in_eparams: input edge parameters, indexed via
                ``node.get_name_indexes()``.
            out_eparams: output edge parameters; only index 0 is bound.
            cname: name of the generated kernel call.

        Returns:
            Always ``True``.
        """
        idx = node.get_name_indexes()
        gen.bindings.append(
            CommentBindingList("Node {} inq {} outq {}", cname, qrec.in_qs[0],
                               qrec.out_qs[0]))

        seq_count = num_sequences(node)
        step = node.step_idx
        # Locals use an unsigned type as wide as the input: 8 bits -> char,
        # any other width -> short.
        elem_type = "unsigned " + ("char" if qrec.in_qs[0].dtype_bits == 8
                                   else "short")
        state1 = f"S{step}_StateInternal01"
        state2 = f"S{step}_StateInternal02"
        for threshold, local_name in ((1, state1), (2, state2)):
            if seq_count > threshold:
                gen.locals.append(LocalArgInfo(elem_type, local_name))

        h_state = in_eparams[idx['h_state']]
        # NOTE(review): set_lstm_bindings/set_gru_bindings pick the reset name
        # with the opposite condition; confirm which convention is current.
        if node.rnn_states_as_inputs:
            reset_name = h_state.creating_node.reset_name
        else:
            reset_name = "Reset"

        args = [
            GNodeArgEdge(h_state, direction="GNA_INOUT"),
            (GNodeArgEdge(state1, alias=h_state, direction="GNA_INOUT")
             if seq_count > 1 else NoArg()),
            (GNodeArgEdge(state2, alias=state1, direction="GNA_INOUT")
             if seq_count > 2 else NoArg()),
            GNodeArgEdge(in_eparams[0]),
            GNodeArgNode(node, "scalenorm"),
        ]
        # For each gate: 'r_2_<gate>_w', 'w_2_<gate>_w', then its bias input(s);
        # gate 'h' has two bias inputs.
        for gate in ('r', 'z', 'h'):
            weight_keys = [f'{src}_2_{gate}_w' for src in ('r', 'w')]
            bias_keys = ['w_h_b', 'r_h_b'] if gate == 'h' else [f'{gate}_b']
            args.extend(GNodeArgEdge(in_eparams[idx[key]])
                        for key in weight_keys + bias_keys)
        args += [
            GNodeArgEdge(out_eparams[0], direction="GNA_OUT"),
            GNodeArgNode(node, INFOS),
            GArgName(reset_name),
        ]

        gen.bindings.append(NodeBindingList(cname, *args))
        return True
Ejemplo n.º 3
0
def set_lstm_bindings(gen, step_idx, in_eparams, out_eparams, cname,
                      rnn_params, rnn_q):
    """Append the comment, local buffers and call bindings for an LSTM node.

    Appends one ``CommentBindingList`` and one ``NodeBindingList`` to
    ``gen.bindings``. Depending on ``num_sequences(rnn_params)``, it also
    appends zero to two pairs of int8 ``S<step_idx>_CellInternal0N`` /
    ``S<step_idx>_StateInternal0N`` locals to ``gen.locals``.

    Args:
        gen: code generator whose ``bindings`` and ``locals`` lists are
            appended to.
        step_idx: step index used to build the internal local names.
        in_eparams: input edge parameters, indexed via
            ``rnn_params.get_name_indexes()``.
        out_eparams: output edge parameters; only index 0 is bound.
        cname: name of the generated kernel call.
        rnn_params: LSTM node parameters.
        rnn_q: quantization record; ``in_qs[0]``/``out_qs[0]`` are printed in
            the comment binding.
    """
    names = rnn_params.get_name_indexes()

    gen.bindings.append(
        CommentBindingList("Node {} inq {} outq {}", cname, rnn_q.in_qs[0],
                           rnn_q.out_qs[0]))
    c_state_eparams = in_eparams[names['c_state']]
    i_state_eparams = in_eparams[names['i_state']]

    num_seq = num_sequences(rnn_params)
    # One extra int8 cell local and one extra int8 state local per additional
    # sequence, up to two.
    if num_seq > 1:
        gen.locals.append(LocalArgInfo("int8",
                                       "S%s_CellInternal01" % step_idx))
        gen.locals.append(
            LocalArgInfo("int8", "S%s_StateInternal01" % step_idx))

    if num_seq > 2:
        gen.locals.append(LocalArgInfo("int8",
                                       "S%s_CellInternal02" % step_idx))
        gen.locals.append(
            LocalArgInfo("int8", "S%s_StateInternal02" % step_idx))

    reset_name = (i_state_eparams.creating_node.reset_name
                  if not rnn_params.rnn_states_as_inputs else "Reset")
    gen.bindings.append(
        NodeBindingList(
            cname, GNodeArgEdge(c_state_eparams, direction="GNA_INOUT"),
            GNodeArgEdge(i_state_eparams, direction="GNA_INOUT"),
            GNodeArgEdge("S%s_CellInternal01" % step_idx,
                         alias=c_state_eparams,
                         direction="GNA_INOUT") if num_seq > 1 else NoArg(),
            GNodeArgEdge("S%s_StateInternal01" % step_idx,
                         alias=i_state_eparams,
                         direction="GNA_INOUT") if num_seq > 1 else NoArg(),
            GNodeArgEdge("S%s_CellInternal02" % step_idx,
                         alias="S%s_CellInternal01" % step_idx,
                         direction="GNA_INOUT") if num_seq > 2 else NoArg(),
            # Second state local aliases the first *state* local, mirroring
            # CellInternal02 -> CellInternal01 just above.
            GNodeArgEdge("S%s_StateInternal02" % step_idx,
                         alias="S%s_StateInternal01" % step_idx,
                         direction="GNA_INOUT") if num_seq > 2 else NoArg(),
            GNodeArgEdge(in_eparams[0]),
            GNodeArgEdge(in_eparams[names['r_2_f_w']]),
            GNodeArgEdge(in_eparams[names['f_b']]),
            GNodeArgEdge(in_eparams[names['r_2_i_w']]),
            GNodeArgEdge(in_eparams[names['i_b']]),
            GNodeArgEdge(in_eparams[names['r_2_c_w']]),
            GNodeArgEdge(in_eparams[names['c_b']]),
            GNodeArgEdge(in_eparams[names['r_2_o_w']]),
            GNodeArgEdge(in_eparams[names['o_b']]),
            GNodeArgEdge(out_eparams[0], direction="GNA_OUT"),
            GNodeArgNode(rnn_params, INFOS), GArgName(reset_name)))
Ejemplo n.º 4
0
    def bindings_generator(cls, gen, node, qrec, in_eparams, out_eparams,
                           cname) -> bool:
        """Generate the bindings for an RNN-style node.

        Appends one ``CommentBindingList`` to ``gen.bindings``, then zero to
        two uint8 ``S<step_idx>_StateInternal0N`` locals to ``gen.locals``
        (depending on ``num_sequences(node)``), then one ``NodeBindingList``
        with the kernel call arguments to ``gen.bindings``.

        Args:
            gen: code generator collecting ``bindings`` and ``locals``.
            node: node parameters; supplies the name indexes, ``step_idx``,
                ``rnn_states_as_inputs`` and ``rnn_same_inout_scale``.
            qrec: quantization record used for the comment binding.
            in_eparams: input edge parameters, indexed via
                ``node.get_name_indexes()``.
            out_eparams: output edge parameters; only index 0 is bound.
            cname: name of the generated kernel call.

        Returns:
            Always ``True``.
        """
        idx = node.get_name_indexes()
        in_qs = qrec.in_qs
        gen.bindings.append(
            CommentBindingList("Node {} inq {} weightsq {} outq {} biasesq {}",
                               cname, in_qs[0], in_qs[idx['r_2_i_w']],
                               qrec.out_qs[0], in_qs[idx['i_b']]))

        seq_count = num_sequences(node)
        has_first = seq_count > 1
        has_second = seq_count > 2
        # Names are only built when needed (step_idx is read lazily).
        first = f"S{node.step_idx}_StateInternal01" if has_first else None
        second = f"S{node.step_idx}_StateInternal02" if has_second else None
        for local_name in (first, second):
            if local_name is not None:
                gen.locals.append(LocalArgInfo("uint8", local_name))

        h_state = in_eparams[idx['i_state']]
        # NOTE(review): set_lstm_bindings/set_gru_bindings pick the reset name
        # with the opposite condition; confirm which convention is current.
        reset_name = ("Reset" if not node.rnn_states_as_inputs
                      else h_state.creating_node.reset_name)

        args = [
            GNodeArgEdge(h_state, direction="GNA_INOUT"),
            (GNodeArgEdge(first, alias=h_state, direction="GNA_INOUT")
             if has_first else NoArg()),
            (GNodeArgEdge(second, alias=first, direction="GNA_INOUT")
             if has_second else NoArg()),
            GNodeArgEdge(in_eparams[0]),
            GNodeArgNode(node, "scalenorm"),
            GNodeArgEdge(in_eparams[idx['r_2_i_w']]),
            # The 'i_2_i_w' input is bound only when rnn_same_inout_scale is
            # falsy.
            (NoArg() if node.rnn_same_inout_scale
             else GNodeArgEdge(in_eparams[idx['i_2_i_w']])),
            GNodeArgEdge(in_eparams[idx['i_b']]),
            GNodeArgEdge(out_eparams[0], direction="GNA_OUT"),
            GNodeArgNode(node, INFOS),
            GArgName(reset_name),
        ]
        gen.bindings.append(NodeBindingList(cname, *args))
        return True
Ejemplo n.º 5
0
def set_gru_bindings(gen, step_idx, in_eparams, out_eparams, cname,
                     rnn_params, rnn_q):
    """Append the comment, local buffers and call bindings for a GRU node.

    Appends one ``CommentBindingList`` and one ``NodeBindingList`` to
    ``gen.bindings``. Depending on ``num_sequences(rnn_params)``, it also
    appends zero to two int8 ``S<step_idx>_StateInternal0N`` locals to
    ``gen.locals``.

    Args:
        gen: code generator whose ``bindings`` and ``locals`` lists are
            appended to.
        step_idx: step index used to build the internal local names.
        in_eparams: input edge parameters, indexed via
            ``rnn_params.get_name_indexes()``.
        out_eparams: output edge parameters; only index 0 is bound.
        cname: name of the generated kernel call.
        rnn_params: GRU node parameters.
        rnn_q: quantization record; ``in_qs[0]``/``out_qs[0]`` are printed in
            the comment binding.
    """
    names = rnn_params.get_name_indexes()

    gen.bindings.append(
        CommentBindingList("Node {} inq {} outq {}", cname,
                           rnn_q.in_qs[0],
                           rnn_q.out_qs[0])
    )
    num_seq = num_sequences(rnn_params)
    # One extra int8 local per additional sequence, up to two.
    if num_seq > 1:
        gen.locals.append(LocalArgInfo(
            "int8", "S%s_StateInternal01" % step_idx))

    if num_seq > 2:
        gen.locals.append(LocalArgInfo(
            "int8", "S%s_StateInternal02" % step_idx))

    i_state_eparams = in_eparams[names['h_state']]
    reset_name = (i_state_eparams.creating_node.reset_name
                  if not rnn_params.rnn_states_as_inputs
                  else f"{rnn_params.name}_Reset")
    # NOTE(review): unlike the other RNN binding helpers, no
    # GNodeArgNode(rnn_params, INFOS) argument is passed here; confirm this
    # matches the GRU kernel's argument list.
    gen.bindings.append(
        NodeBindingList(cname,
                        GNodeArgEdge(i_state_eparams, direction="GNA_INOUT"),
                        GNodeArgEdge("S%s_StateInternal01" % step_idx,
                                     alias=i_state_eparams,
                                     direction="GNA_INOUT") if num_seq > 1 else NoArg(),
                        # The alias must name a local declared above: a GRU
                        # only has StateInternal buffers (no CellInternal).
                        GNodeArgEdge("S%s_StateInternal02" % step_idx,
                                     alias="S%s_StateInternal01" % step_idx,
                                     direction="GNA_INOUT") if num_seq > 2 else NoArg(),
                        GNodeArgEdge(in_eparams[0]),
                        GNodeArgEdge(in_eparams[names['r_2_r_w']]),
                        GNodeArgEdge(in_eparams[names['r_b']]),
                        GNodeArgEdge(in_eparams[names['r_2_z_w']]),
                        GNodeArgEdge(in_eparams[names['z_b']]),
                        GNodeArgEdge(in_eparams[names['r_2_h_w']]),
                        GNodeArgEdge(in_eparams[names['w_h_b']]),
                        GNodeArgEdge(in_eparams[names['r_h_b']]),
                        GNodeArgEdge(out_eparams[0], direction="GNA_OUT"),
                        GArgName(reset_name)
                        ))