def _get_weight_and_bias_for_lstm_cell(self, context):
        """Extract the LSTM weights for the LSTMCell pattern held in ``context``.

        Resolves the "cell_kernel", "cell_bias" and "ft_bias" ops of
        ``context.cell_match`` to constant weights via
        ``get_weights_from_const_node``. ``None`` is returned (the rewrite is
        skipped) when any of them cannot be resolved, when the "bias_add" op is
        not NHWC, when the bias length does not equal the kernel's second
        dimension, or when the bias and forget-bias dtypes differ.

        Args:
            context: rewriter context whose ``cell_match`` is queried by op name.

        Returns:
            ``RnnWeights(w, b, ft)`` on success, otherwise ``None``.
        """
        match = context.cell_match

        w_e = match.get_op("cell_kernel")
        w = get_weights_from_const_node(self.g, w_e)
        if not w:
            return None

        # check https://www.tensorflow.org/versions/r1.8/api_docs/cc/class/tensorflow/ops/bias-add
        # for bias_add data format
        bias_add = match.get_op("bias_add")
        if bias_add.data_format != "NHWC":
            log.debug("BiasAdd data_format is not NHWC, SKIP")
            return None

        b_e = match.get_op("cell_bias")
        b = get_weights_from_const_node(self.g, b_e)
        # Guard the shape lookups so a kernel with fewer than two dimensions or
        # a 0-d bias is treated as a mismatch instead of raising IndexError.
        shapes_match = (
            bool(b)
            and len(w.value.shape) >= 2
            and len(b.value.shape) >= 1
            and b.value.shape[0] == w.value.shape[1]
        )
        if not shapes_match:
            log.warning(
                "cell_kernel and cell_bias's dimensions does not match, skip")
            return None

        ft_bias = match.get_op("ft_bias")
        ft = get_weights_from_const_node(self.g, ft_bias)
        if not ft:
            return None

        # Bias and forget bias must share a dtype to be combined downstream.
        if b.dtype != ft.dtype:
            return None

        return RnnWeights(w, b, ft)
# Example No. 2
# 0
    def get_weight_and_bias(self, match):
        """Extract the LSTM weights for a matched cell pattern.

        Resolves the "cell_kernel", "cell_bias" and "ft_bias" ops of ``match``
        to constant weights via ``get_weights_from_const_node``. ``None`` is
        returned (the rewrite is skipped) when any of them cannot be resolved,
        when the "bias_add" op is not NHWC, when the bias length does not
        equal the kernel's second dimension, when the forget bias is not 1, or
        when the graph dtypes of the bias and forget-bias outputs differ.

        Args:
            match: pattern match object queried by op name via ``get_op``.

        Returns:
            ``RnnWeights(w, b, ft)`` on success, otherwise ``None``.
        """
        # if one of them is not match, just return
        w_e = match.get_op("cell_kernel")
        w = get_weights_from_const_node(self.g, w_e)
        if not w:
            return None

        # check https://www.tensorflow.org/versions/r1.8/api_docs/cc/class/tensorflow/ops/bias-add
        # for bias_add data format
        bias_add = match.get_op("bias_add")
        if bias_add.data_format != "NHWC":
            log.debug("BiasAdd data_format is not NHWC, SKIP")
            return None

        b_e = match.get_op("cell_bias")
        b = get_weights_from_const_node(self.g, b_e)
        # Guard the shape lookups so a kernel with fewer than two dimensions or
        # a 0-d bias is treated as a mismatch instead of raising IndexError.
        shapes_match = (
            bool(b)
            and len(w.value.shape) >= 2
            and len(b.value.shape) >= 1
            and b.value.shape[0] == w.value.shape[1]
        )
        if not shapes_match:
            log.warning(
                "cell_kernel and cell_bias's dimensions does not match, skip")
            return None

        ft_bias = match.get_op("ft_bias")
        ft = get_weights_from_const_node(self.g, ft_bias)
        if not ft:
            return None

        # Only a forget bias of exactly 1 is accepted. Reducing with np.all
        # keeps the test well-defined for a non-scalar constant, where the
        # truthiness of a bare ``ft.value == 1`` array would be ambiguous and
        # raise; an empty constant is rejected.
        if np.size(ft.value) == 0 or not np.all(ft.value == 1):
            return None

        # Checked only after the forget-bias test, matching the original
        # short-circuit order of evaluation.
        if (self.g.get_dtype(b_e.output[0])
                != self.g.get_dtype(ft_bias.output[0])):
            return None

        return RnnWeights(w, b, ft)
    def _get_weight_and_bias_for_lstmblock_cell(self, context):
        """Extract the LSTM weights for an LSTMBlockCell pattern in ``context``.

        The kernel and bias are resolved from the "cell_kernel" and
        "cell_bias" ops of ``context.cell_match`` via
        ``get_weights_from_const_node``. The forget bias is read from the
        ``forget_bias`` attribute of the "lstm_block_cell" op and wrapped in an
        ``RnnWeight`` that has no backing node.

        Args:
            context: rewriter context whose ``cell_match`` is queried by op name.

        Returns:
            ``RnnWeights(w, b, ft_bias)`` on success, or ``None`` when the
            kernel or bias cannot be resolved or their dimensions disagree.
        """
        cell_match = context.cell_match

        w_node = cell_match.get_op("cell_kernel")
        w = get_weights_from_const_node(self.g, w_node)
        if not w:
            log.warning("Cannot find weight, SKIP")
            return None

        b_node = cell_match.get_op("cell_bias")
        b = get_weights_from_const_node(self.g, b_node)
        # Guard the shape lookups so a kernel with fewer than two dimensions or
        # a 0-d bias is treated as a mismatch instead of raising IndexError.
        shapes_match = (
            bool(b)
            and len(w.value.shape) >= 2
            and len(b.value.shape) >= 1
            and b.value.shape[0] == w.value.shape[1]
        )
        if not shapes_match:
            log.warning(
                "cell_kernel and cell_bias's dimension doesn't match, SKIP")
            return None

        lstm_block_cell = cell_match.get_op("lstm_block_cell")
        # TensorFlow declares LSTMBlockCell's ``forget_bias`` with a default of
        # 1.0, and graphs exported with default attributes stripped omit it.
        # NOTE(review): assumes get_attr returns None for a missing attribute;
        # confirm against the node implementation.
        forget_bias_attr = lstm_block_cell.get_attr("forget_bias")
        forget_bias = 1.0 if forget_bias_attr is None else forget_bias_attr.f
        ft_bias_val = np.array(forget_bias, dtype=b.dtype)
        ft_bias = RnnWeight(None, ft_bias_val, ft_bias_val.dtype)

        return RnnWeights(w, b, ft_bias)