def _get_weight_and_bias_for_lstm_cell(self, context):
    match = context.cell_match
    w_e = match.get_op("cell_kernel")
    w = get_weights_from_const_node(self.g, w_e)
    if w is None:
        return None

    # check https://www.tensorflow.org/versions/r1.8/api_docs/cc/class/tensorflow/ops/bias-add
    # for bias_add data format
    bias_add = match.get_op("bias_add")
    if bias_add is not None and bias_add.data_format != "NHWC":
        logger.debug("BiasAdd data_format is not NHWC, SKIP")
        return None

    b_e = match.get_op("cell_bias")
    if b_e is None:
        # no bias in the graph, default to zeros matching the kernel's width and dtype
        b = np.zeros(w.shape[1], dtype=w.dtype)
    else:
        b = get_weights_from_const_node(self.g, b_e)
        if b is None or b.shape[0] != w.shape[1]:
            logger.warning("cell_kernel and cell_bias's dimensions do not match, skip")
            return None

    ft_bias_node = match.get_op("ft_bias")
    ft_bias = get_weights_from_const_node(self.g, ft_bias_node)
    if ft_bias is None:
        return None
    if b.dtype != ft_bias.dtype:
        return None

    return {"weight": w, "bias": b, "ft_bias": ft_bias}

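# For reference, a minimal sketch of the get_weights_from_const_node helper
# these matchers depend on. This is an assumption about its behavior, not the
# converter's actual implementation: the real helper also walks through
# Enter/Identity wrappers before reading the constant's value.
def _get_weights_from_const_node_sketch(g, node):
    if node is None or not node.is_const():
        return None
    # read the node's backing tensor as a numpy array
    return np.asarray(node.get_tensor_value(as_list=False))
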
def _get_weight_and_bias_for_lstmblock_cell(self, context, i):
    cell_match = context.cell_match[i]
    w_node = cell_match.get_op("cell_kernel")
    w = get_weights_from_const_node(self.g, w_node)
    if w is None:
        logger.warning("Cannot find weight, SKIP")
        return None

    b_node = cell_match.get_op("cell_bias")
    b = get_weights_from_const_node(self.g, b_node)
    if b is None or b.shape[0] != w.shape[1]:
        logger.warning("cell_kernel and cell_bias's dimensions do not match, SKIP")
        return None

    # LSTMBlockCell carries the forget-gate bias as a scalar node attribute
    # rather than a const input
    lstm_block_cell = cell_match.get_op("lstm_block_cell")
    ft_bias_val = np.array(lstm_block_cell.get_attr("forget_bias").f, dtype=b.dtype)

    return {"weight": w, "bias": b, "ft_bias": ft_bias_val}

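# Hedged sketch (not the converter's code) of how the scalar ft_bias is
# eventually folded into the cell bias. It assumes TF's i, c, f, o gate
# layout, under which the forget-gate slice is the third hidden_size block.
def _add_forget_bias_sketch(b, ft_bias, hidden_size):
    b = b.copy()
    b[2 * hidden_size:3 * hidden_size] += ft_bias  # f-gate block in icfo layout
    return b
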
def get_weight_and_bias(self, match):
    # if any of them is missing, just return
    w_e = match.get_op("cell_kernel")
    w = get_weights_from_const_node(self.g, w_e)
    if not w:
        return None

    # check https://www.tensorflow.org/versions/r1.8/api_docs/cc/class/tensorflow/ops/bias-add
    # for bias_add data format
    bias_add = match.get_op("bias_add")
    if bias_add.data_format != "NHWC":
        log.debug("BiasAdd data_format is not NHWC, SKIP")
        return None

    b_e = match.get_op("cell_bias")
    b = get_weights_from_const_node(self.g, b_e)
    if not b or b.value.shape[0] != w.value.shape[1]:
        log.warning("cell_kernel and cell_bias's dimensions do not match, skip")
        return None

    ft_bias = match.get_op("ft_bias")
    ft = get_weights_from_const_node(self.g, ft_bias)
    if not ft:
        return None
    # the forget bias must be 1 and share the cell bias's dtype
    if not (ft.value == 1 and
            self.g.get_dtype(b_e.output[0]) == self.g.get_dtype(ft_bias.output[0])):
        return None

    return RnnWeights(w, b, ft)

def _get_weight_and_bias_for_lstm_cell(self, context):
    match = context.cell_match
    w_e = match.get_op("cell_kernel")
    w = get_weights_from_const_node(self.g, w_e)
    if not w:
        return None

    # check https://www.tensorflow.org/versions/r1.8/api_docs/cc/class/tensorflow/ops/bias-add
    # for bias_add data format
    bias_add = match.get_op("bias_add")
    if bias_add.data_format != "NHWC":
        log.debug("BiasAdd data_format is not NHWC, SKIP")
        return None

    b_e = match.get_op("cell_bias")
    b = get_weights_from_const_node(self.g, b_e)
    if not b or b.value.shape[0] != w.value.shape[1]:
        log.warning("cell_kernel and cell_bias's dimensions do not match, skip")
        return None

    ft_bias = match.get_op("ft_bias")
    ft = get_weights_from_const_node(self.g, ft_bias)
    if not ft:
        return None
    if b.dtype != ft.dtype:
        return None

    return RnnWeights(w, b, ft)

def get_weight_and_bias(self, match):
    gate_kernel = get_weights_from_const_node(self.g, match.get_op("gate_kernel"))
    gate_bias = get_weights_from_const_node(self.g, match.get_op("gate_bias"))
    hidden_kernel = get_weights_from_const_node(self.g, match.get_op("hidden_kernel"))
    hidden_bias = get_weights_from_const_node(self.g, match.get_op("hidden_bias"))
    # check against None explicitly: truth-testing a numpy array raises ValueError
    if not all(val is not None for val in
               [gate_kernel, gate_bias, hidden_kernel, hidden_bias]):
        log.debug("rnn weights check failed, skip")
        return None

    log.debug("found all needed weights")
    return {"gate_kernel": gate_kernel,
            "gate_bias": gate_bias,
            "hidden_kernel": hidden_kernel,
            "hidden_bias": hidden_bias}

def get_weight_and_bias(self, context):
    match = context.cell_match
    gate_kernel = get_weights_from_const_node(self.g, match.get_op("gate_kernel"))
    gate_bias = get_weights_from_const_node(self.g, match.get_op("gate_bias"))
    res = {"gate_kernel": gate_kernel, "gate_bias": gate_bias}

    # the cell variants differ on the memory gate:
    # GRUCell:                h'_t = tanh(concat(x_t, r_t .* h_t-1) * W + b)
    # CudnnCompatibleGRUCell: h'_t = tanh(x_t * W_x + b_x + r_t .* (h_t-1 * W_h + b_h))
    if self.gru_cell_type == RNNUnitType.CudnnCompatibleGRUCell:
        hidden_state_kernel = get_weights_from_const_node(
            self.g, match.get_op("hidden_state_kernel"))
        hidden_state_bias = get_weights_from_const_node(
            self.g, match.get_op("hidden_state_bias"))
        hidden_input_kernel = get_weights_from_const_node(
            self.g, match.get_op("hidden_input_kernel"))
        hidden_input_bias = get_weights_from_const_node(
            self.g, match.get_op("hidden_input_bias"))
        if not all(val is not None for val in [
                hidden_state_kernel, hidden_state_bias,
                hidden_input_kernel, hidden_input_bias
        ]):
            logger.debug("rnn weights check failed, skip")
            return None
        hidden_kernel = np.concatenate([hidden_input_kernel, hidden_state_kernel])
        # apply the linear transformation before multiplying by the output
        # of the reset gate
        context.attributes["linear_before_reset"] = 1
        res["hidden_kernel"] = hidden_kernel
        res["hidden_bias"] = hidden_input_bias
        # recurrence bias for the hidden gate
        res["Rb_h"] = hidden_state_bias
    elif self.gru_cell_type in [RNNUnitType.GRUCell, RNNUnitType.GRUBlockCell]:
        hidden_kernel = get_weights_from_const_node(
            self.g, match.get_op("hidden_kernel"))
        hidden_bias = get_weights_from_const_node(
            self.g, match.get_op("hidden_bias"))
        res["hidden_kernel"] = hidden_kernel
        res["hidden_bias"] = hidden_bias

    if not all(val is not None for val in res.values()):
        logger.debug("rnn weights check failed, skip")
        return None

    logger.debug("found all needed weights")
    return res

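# Illustrative numpy check (made-up shapes, not part of the rewriter) of the
# memory-gate difference documented above: GRUCell applies the reset gate
# before its single fused matmul, while CudnnCompatibleGRUCell applies it
# after the hidden matmul, which is exactly ONNX's linear_before_reset=1
# semantics.
def _gru_candidate_demo():
    rng = np.random.default_rng(0)
    x, h, r = rng.normal(size=4), rng.normal(size=3), rng.uniform(size=3)
    w_x, w_h = rng.normal(size=(4, 3)), rng.normal(size=(3, 3))
    b_x, b_h = rng.normal(size=3), rng.normal(size=3)
    # GRUCell: reset applied to the state, then one fused matmul over the
    # concatenated kernel, mirroring the np.concatenate above
    w = np.concatenate([w_x, w_h])
    h_gru = np.tanh(np.concatenate([x, r * h]) @ w + b_x)
    # CudnnCompatibleGRUCell: separate matmuls, reset applied afterwards
    h_cudnn = np.tanh(x @ w_x + b_x + r * (h @ w_h + b_h))
    return h_gru, h_cudnn
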
def get_weight_and_bias(self, match):
    node = match.get_op("GRUBlockCell")
    # from the TF kernel, the op's inputs and their meanings are:
    # 0-input, 1-state, 2-gate_kernel, 3-hidden_kernel, 4-gate_bias, 5-hidden_bias
    gate_kernel = get_weights_from_const_node(node.inputs[2].inputs[0])
    gate_bias = get_weights_from_const_node(node.inputs[4].inputs[0])
    hidden_kernel = get_weights_from_const_node(node.inputs[3].inputs[0])
    hidden_bias = get_weights_from_const_node(node.inputs[5].inputs[0])
    # check against None explicitly: truth-testing a numpy array raises ValueError
    if not all(val is not None for val in
               [gate_kernel, gate_bias, hidden_kernel, hidden_bias]):
        log.error("rnn weights check failed, skip")
        sys.exit(-1)

    log.debug("found all needed weights")
    return {"gate_kernel": gate_kernel,
            "gate_bias": gate_bias,
            "hidden_kernel": hidden_kernel,
            "hidden_bias": hidden_bias}

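# A hedged, table-driven variant of the lookups above (illustrative only),
# keyed by the input indices documented in the comment; it assumes the same
# one-argument get_weights_from_const_node this version of the code uses.
def _collect_grublock_weights_sketch(node):
    indices = {"gate_kernel": 2, "hidden_kernel": 3,
               "gate_bias": 4, "hidden_bias": 5}
    res = {name: get_weights_from_const_node(node.inputs[i].inputs[0])
           for name, i in indices.items()}
    if any(val is None for val in res.values()):
        return None
    return res
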
def _get_weight_and_bias_for_lstmblock_cell(self, context):
    cell_match = context.cell_match
    w_node = cell_match.get_op("cell_kernel")
    w = get_weights_from_const_node(self.g, w_node)
    if not w:
        log.warning("Cannot find weight, SKIP")
        return None

    b_node = cell_match.get_op("cell_bias")
    b = get_weights_from_const_node(self.g, b_node)
    if not b or b.value.shape[0] != w.value.shape[1]:
        log.warning("cell_kernel and cell_bias's dimensions do not match, SKIP")
        return None

    # LSTMBlockCell keeps the forget-gate bias as a scalar attribute rather
    # than a const node, so wrap it in an RnnWeight with no backing node
    lstm_block_cell = cell_match.get_op("lstm_block_cell")
    ft_bias_val = np.array(lstm_block_cell.get_attr("forget_bias").f, dtype=b.dtype)
    ft_bias = RnnWeight(None, ft_bias_val, ft_bias_val.dtype)

    return RnnWeights(w, b, ft_bias)
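
# Hedged reconstruction of the weight containers used above; the field names
# are inferred from the call sites (RnnWeight(None, value, dtype), w.value,
# b.dtype, RnnWeights(w, b, ft)), not taken from the original module.
import collections

RnnWeightSketch = collections.namedtuple(
    "RnnWeightSketch", ["node", "value", "dtype"])
RnnWeightsSketch = collections.namedtuple(
    "RnnWeightsSketch", ["kernel", "bias", "forget_bias"])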