コード例 #1
0
    def __init__(self, node_name, cname, params, qrec, gen_ctrl=None):
        """Build the kernel attributes for an RNN node and initialise the base kernel.

        Args:
            node_name: name of the graph node; forwarded in ``extra_attrs``.
            cname: C name of the generated kernel; stored on ``gen_ctrl`` and
                forwarded in ``extra_attrs``.
            params: RNN node parameters (reads ``hard_act``, ``name``,
                ``n_cells``, ``n_input_cells``, ``n_output_cells``,
                ``n_states``, ``n_inputs`` and ``revert``).
            qrec: quantization record; its ``in_qs`` are indexed by the
                position of each name in ``RNNParameters.INPUT_NAMES``.
            gen_ctrl: optional generator-control object. When omitted, a new
                ``GenCtrl(None, cname=cname)`` is created; otherwise its
                ``cname`` is overwritten.

        Raises:
            ValueError: if the ``i_2_i_w`` and ``r_2_i_w`` weights are
                quantized with different bit widths.
        """
        if gen_ctrl is None:
            gen_ctrl = GenCtrl(None, cname=cname)
        else:
            gen_ctrl.cname = cname

        if params.hard_act:
            gen_ctrl.rnn_use_hardact = 1

        # Map each input name to its index in qrec.in_qs.
        names = {val: idx for idx, val in enumerate(RNNParameters.INPUT_NAMES)}
        in_qs = qrec.in_qs

        # Only one 'filter_bits' value is emitted below, so the input and
        # recurrent weights must agree. Previously the exception was built
        # but never raised, which silently accepted mismatched widths.
        filter_bits = in_qs[names['i_2_i_w']].bits
        if filter_bits != in_qs[names['r_2_i_w']].bits:
            raise ValueError(f'bit width of gates differs in {params.name}')

        attrs = {
            'bias_size': in_qs[names['i_b']].dtype_bits // 8,
            # Unary minus binds before //, so this is (-dtype_bits) // 8.
            # The negative size presumably marks a signed feature type for
            # the code generator -- confirm against the kernel template.
            'feat_size': -in_qs[0].dtype_bits // 8,
            'filter_bits': filter_bits,
            'n_cells': params.n_cells,
            'k0': params.n_input_cells,
            'k1': params.n_output_cells,
            'dim_state': params.n_states,
            'dim_in': params.n_inputs,
            'always_reset': 0,
            'revert': 1 if params.revert else 0,
        }

        extra_attrs = {'cname': cname, 'node_name': node_name}
        super().__init__(attrs, extra_attrs, gen_ctrl=gen_ctrl)
コード例 #2
0
    def __init__(self, node_name, cname, params, qrec, gen_ctrl=None):
        """Build the kernel attributes for an LSTM node and initialise the base kernel.

        Args:
            node_name: name of the graph node; forwarded in ``extra_attrs``.
            cname: C name of the generated kernel; stored on ``gen_ctrl`` and
                forwarded in ``extra_attrs``.
            params: LSTM node parameters (reads ``hard_act``, ``name``,
                ``n_cells``, ``n_input_cells``, ``n_output_cells``,
                ``n_states``, ``n_inputs`` and ``revert``).
            qrec: quantization record; its ``in_qs`` are indexed by the
                position of each name in ``LSTMParameters.INPUT_NAMES``, and
                ``cache['i_2_f_q']`` is read when ``params.hard_act`` is set.
            gen_ctrl: optional generator-control object. When omitted, a new
                ``GenCtrl(None, cname=cname)`` is created; otherwise its
                ``cname`` is overwritten.

        Raises:
            ValueError: if the input (``i_2_*_w``) and recurrent (``r_2_*_w``)
                weights of the f/i/c/o gates do not all share one bit width.
        """
        if gen_ctrl is None:
            gen_ctrl = GenCtrl(None, cname=cname)
        else:
            gen_ctrl.cname = cname

        if params.hard_act:
            gen_ctrl.rnn_use_hardact = 1
            # Copy the pre-normalization of the 'i_2_f_q' cache entry.
            gen_ctrl.gate_prenorm = qrec.cache['i_2_f_q'].pre_normalization

        # Map each input name to its index in qrec.in_qs.
        names = {
            val: idx
            for idx, val in enumerate(LSTMParameters.INPUT_NAMES)
        }
        in_qs = qrec.in_qs

        # Only one 'filter_bits' value is emitted below, so every gate weight
        # must agree. Previously the exception was built but never raised,
        # which silently accepted mismatched widths.
        weight_bits = [
            in_qs[names[f'{inp_t}_2_{gate}_w']].bits
            for gate in ('f', 'i', 'c', 'o')
            for inp_t in ('r', 'i')
        ]
        w_bits = weight_bits[0]
        if any(bits != w_bits for bits in weight_bits):
            raise ValueError(f'bit width of gates differs in {params.name}')

        attrs = {
            'bias_size': in_qs[names['i_b']].dtype_bits // 8,
            # Unary minus binds before //, so this is (-dtype_bits) // 8.
            # The negative size presumably marks a signed feature type for
            # the code generator -- confirm against the kernel template.
            'feat_size': -in_qs[0].dtype_bits // 8,
            'filter_bits': w_bits,
            'n_cells': params.n_cells,
            'k0': params.n_input_cells,
            'k1': params.n_output_cells,
            'dim_state': params.n_states,
            'dim_in': params.n_inputs,
            'always_reset': 0,
            'revert': 1 if params.revert else 0,
        }

        extra_attrs = {'cname': cname, 'node_name': node_name}
        super().__init__(attrs, extra_attrs, gen_ctrl=gen_ctrl)