def _ct_ht_shared_variable_finder(self, context, i):
        """Find the state variable that carries both ct and ht of the i-th matched cell.

        The "ct" and "ht" ops of the matched cell must both feed exactly one,
        common concat op; the output of that concat is treated as the shared
        state output. The slice ops feeding the "ct_identity_consumer" and "xh"
        ops are collected as the consumers of that state.

        Returns None when the cell type is LSTMBlockCell, when the common
        concat cannot be found, or when either consumer does not have exactly
        one slice input. Otherwise returns whatever
        `_find_state_variable_with_select` returns for the shared output and
        the two slice ops.
        """
        if self.lstm_cell_type == RNNUnitType.LSTMBlockCell:
            return None

        def concat_consumers_of(tensor):
            # keep only the concat ops among the consumers of `tensor`
            return list(filter(is_tf_concat_op, self.g.find_output_consumers(tensor)))

        match = context.cell_match[i]
        ct_output = match.get_op("ct").output[0]
        ht_output = match.get_op("ht").output[0]
        ct_concats = concat_consumers_of(ct_output)
        ht_concats = concat_consumers_of(ht_output)

        # ct and ht must each have a single concat consumer, and it must be the same op
        exactly_one_each = len(ct_concats) == 1 and len(ht_concats) == 1
        if not exactly_one_each or ct_concats[0] != ht_concats[0]:
            logger.debug("failed to find ct-ht concat")
            return None
        shared_output = ct_concats[0].output[0]

        ct_consumer = match.get_op("ct_identity_consumer")
        ht_consumer = match.get_op("xh")
        ct_slices = list(filter(is_tf_slice_op, ct_consumer.inputs))
        ht_slices = list(filter(is_tf_slice_op, ht_consumer.inputs))
        if not (len(ct_slices) == 1 and len(ht_slices) == 1):
            logger.debug("failed to find slice op before identity consumers")
            return None

        state_consumers = [ct_slices[0], ht_slices[0]]
        return self._find_state_variable_with_select(context, shared_output, state_consumers)
def infer_output_shapes_with_partial_inputs(op):
    """Infer and set the output shape of `op` from a subset of its input shapes.

    Handles concat ops, "Select"/"SelectV2", "Pack" and "Pow". For these ops at
    least one known input shape is enough to derive (part of) the output shape.
    The shape is written with `set_shape` on `op.outputs[0]`; the Select and
    Pack branches also write it back onto their inputs.

    Returns:
        True if an output shape was set, False if the op is handled here but
        no usable input shape was available, None if the op type is not
        handled by this function.
    """
    # output shape of concat op: only the dim val of concatenated dim will be changed
    # so only partial(at least one) input shapes need to be known to infer output shape of concat op
    if utils.is_tf_concat_op(op):
        data_inputs = op.inputs[:-1]
        input_shapes = [get_tf_tensor_shape(inp) for inp in data_inputs]
        input_shapes = [shape for shape in input_shapes if shape is not None]
        if not input_shapes:
            logger.debug(
                "all input shapes of concat op %s are None, can't infer its output shape",
                op.name)
            return False

        # Work on a copy: new_shape is modified below, and input_shapes[0] is
        # read again afterwards to compute the size of the concatenated dim.
        new_shape = list(input_shapes[0])
        axis_op = op.inputs[-1]
        rank = len(new_shape)
        if not utils.is_tf_const_op(axis_op):
            # axis unknown: only the rank of the output can be stated
            op.outputs[0].set_shape([-1] * rank)
            return True

        axis = get_tf_const_value(axis_op)
        axis = axis if axis >= 0 else axis + rank
        new_shape[axis] = -1
        if len(input_shapes) == len(data_inputs):  # all input shapes are known
            concat_dim_vals = [shape[axis] for shape in input_shapes]
            # only when every input's concat dim is known (neither None nor
            # negative) can the val of the concat dim be calculated
            if all(val is not None and val >= 0 for val in concat_dim_vals):
                new_shape[axis] = sum(concat_dim_vals)

        op.outputs[0].set_shape(new_shape)
        logger.debug("set Concat op [%s] with new shape %s",
                     op.outputs[0].name, new_shape)
        return True

    if op.type in ["Select", "SelectV2"]:
        # NOTE(review): SelectV2 allows broadcasting, so forcing one shape onto
        # both data inputs may be too strict - existing behavior kept, confirm.
        new_shape = get_tf_tensor_shape(op.inputs[1])
        if new_shape is None:
            new_shape = get_tf_tensor_shape(op.inputs[2])
        if new_shape is not None:
            op.outputs[0].set_shape(new_shape)
            op.inputs[1].set_shape(new_shape)
            op.inputs[2].set_shape(new_shape)
            logger.debug("set [%s] with new shape %s", op.outputs[0].name,
                         new_shape)
            return True
        return False

    if op.type == "Pack":
        axis = op.get_attr("axis")
        # use the first input whose shape is known as the shape of every input
        input_shape = None
        for inp in op.inputs:
            known_shape = get_tf_tensor_shape(inp)
            if known_shape is not None:
                input_shape = known_shape
                break
        if input_shape is None:
            return False
        if axis < 0:
            # the packed output has one more dim than its inputs, so a negative
            # axis wraps around rank + 1 (axis=-1 appends the new dim last)
            axis += len(input_shape) + 1
        for inp in op.inputs:
            # only fill in shapes that are really unknown; a known scalar
            # shape is an empty list and must not be treated as missing
            if get_tf_tensor_shape(inp) is None:
                inp.set_shape(input_shape)
                logger.debug("set [%s] with new shape %s", inp.name, input_shape)
        new_shape = input_shape[:axis] + [len(op.inputs)] + input_shape[axis:]
        op.outputs[0].set_shape(new_shape)
        logger.debug("set Pack op [%s] with new shape %s", op.outputs[0].name,
                     new_shape)
        return True

    if op.type == "Pow":
        # https://www.tensorflow.org/api_docs/cc/class/tensorflow/ops/pow
        # NOTE(review): Pow broadcasts its inputs; the first known input shape
        # is used as the output shape as before - confirm this is intended.
        new_shape = get_tf_tensor_shape(op.inputs[0])
        if new_shape is None:
            new_shape = get_tf_tensor_shape(op.inputs[1])
        if new_shape is not None:
            op.outputs[0].set_shape(new_shape)
            logger.debug("set [%s] with new shape %s", op.outputs[0].name,
                         new_shape)
            return True
        return False

    return None