def infer_output_shapes_with_partial_inputs(op):
    """Infer output shapes for ops that do not need every input shape known.

    Handles concat ops (as identified by ``utils.is_tf_concat_op``),
    ``Select``/``SelectV2``, ``Pack`` and ``Pow``.

    Args:
        op: the TF op whose output shape should be set.

    Returns:
        True if a shape was set, False if the op type is handled here but not
        enough shape information was available, or None if the op type is not
        handled by this function (the caller then tries other rules).
    """
    # output shape of concat op: only the dim val of concatenated dim will be changed
    # so only partial(at least one) input shapes need to be known to infer output shape of concat op
    if utils.is_tf_concat_op(op):
        if op.type == "Concat":
            # v1 Concat takes the axis as its FIRST input, followed by the values.
            data_inputs = op.inputs[1:]
            axis_tensor = op.inputs[0]
        else:
            # ConcatV2-style ops take the values first and the axis LAST.
            data_inputs = op.inputs[:-1]
            axis_tensor = op.inputs[-1]
        input_shapes = [get_tf_tensor_shape(inp) for inp in data_inputs]
        input_shapes = [shape for shape in input_shapes if shape is not None]
        if not input_shapes:
            logger.debug(
                "all input shapes of concat op %s are None, can't infer its output shape",
                op.name)
            return False

        # Copy: new_shape is modified below, and writing through an alias of
        # input_shapes[0] would clobber that input's concat-dim value.
        new_shape = list(input_shapes[0])
        # Inspect the op producing the axis tensor, as the other rules here do.
        axis_op = axis_tensor.op
        rank = len(new_shape)
        if not utils.is_tf_const_op(axis_op):
            op.outputs[0].set_shape([-1] * rank)
            return True

        axis = get_tf_const_value(axis_op)
        axis = axis if axis >= 0 else axis + rank
        new_shape[axis] = -1
        if len(input_shapes) == len(data_inputs):  # all input shapes are known
            concat_dim_vals = [shape[axis] for shape in input_shapes]
            # the concat dim is only computable when every input's value for it
            # is known (unknown dims may be reported as -1 or None)
            if all(val is not None and val >= 0 for val in concat_dim_vals):
                new_shape[axis] = sum(concat_dim_vals)

        op.outputs[0].set_shape(new_shape)
        logger.debug("set Concat op [%s] with new shape %s",
                     op.outputs[0].name, new_shape)
        return True

    if op.type in ["Select", "SelectV2"]:
        # Take the shape from whichever value input (t or e) is known.
        new_shape = get_tf_tensor_shape(op.inputs[1])
        if new_shape is None:
            new_shape = get_tf_tensor_shape(op.inputs[2])
        if new_shape is not None:
            op.outputs[0].set_shape(new_shape)
            op.inputs[1].set_shape(new_shape)
            op.inputs[2].set_shape(new_shape)
            logger.debug("set [%s] with new shape %s", op.outputs[0].name,
                         new_shape)
            return True
        return False

    if op.type == "Pack":
        axis = op.get_attr("axis")
        # All packed inputs share one shape, so any known one will do.
        input_shape = next(
            (s for s in (get_tf_tensor_shape(i) for i in op.inputs) if s is not None),
            None)
        if input_shape is None:
            return False
        # The output has rank len(input_shape) + 1, so a negative axis wraps
        # around that rank (valid range [-(R+1), R+1)), as for ExpandDims.
        if axis < 0:
            axis += len(input_shape) + 1
        for i in op.inputs:
            if get_tf_tensor_shape(i) is None:
                i.set_shape(input_shape)
                logger.debug("set [%s] with new shape %s", i.name, input_shape)
        new_shape = input_shape[:axis] + [len(op.inputs)] + input_shape[axis:]
        op.outputs[0].set_shape(new_shape)
        logger.debug("set Pack op [%s] with new shape %s", op.outputs[0].name,
                     new_shape)
        return True

    if op.type == "Pow":
        # https://www.tensorflow.org/api_docs/cc/class/tensorflow/ops/pow
        new_shape = get_tf_tensor_shape(op.inputs[0])
        if new_shape is None:
            new_shape = get_tf_tensor_shape(op.inputs[1])
        if new_shape is not None:
            op.outputs[0].set_shape(new_shape)
            logger.debug("set [%s] with new shape %s", op.outputs[0].name,
                         new_shape)
            return True
        return False

    return None
def infer_shape_for_op_legacy(op):
    """Infer and set shapes for ``op`` using hand-written per-op rules.

    TF's own shape inference (``infer_shape_for_op``) is invoked first; the
    rules below try to fill in shapes that are still unknown afterwards.

    Args:
        op: the TF op to process.

    Returns:
        A truthy value when a rule here (or a helper it delegates to) reports
        having set a shape; False when the output shapes were already known or
        nothing could be inferred.
    """
    # invoke tf shape inference first
    infer_shape_for_op(op)

    has_unknown_input_shape = any(
        get_tf_tensor_shape(inp) is None for inp in op.inputs)
    has_unknown_output_shape = any(
        get_tf_tensor_shape(out) is None for out in op.outputs)

    # an input shape may be inferred from op output or other input shapes
    # try to infer it first
    if has_unknown_input_shape:
        if infer_input_shapes(op):
            return True

    if not has_unknown_output_shape:
        return False

    # for those ops, we don't expect all input shapes available to infer output shapes.
    ret = infer_output_shapes_with_partial_inputs(op)
    if ret is not None:
        return ret

    # for ops, we need all input shapes ready to infer output shapes.
    no_shape = [inp.name for inp in op.inputs if get_tf_tensor_shape(inp) is None]
    if no_shape:
        logger.debug(
            "op %s has inputs don't have shape specified, they are: %s",
            op.name, no_shape)
        return False

    if op.type in direct_ops:
        return set_shape_from_input(op.inputs[0], op.outputs[0])

    if op.type in broadcast_ops:
        return set_shape_from_inputs_broadcast(op.inputs, op.outputs[0])

    if op.type == "RandomUniform":
        # Only handled when the requested shape comes straight from a Shape op.
        shape_op = op.inputs[0].op
        if not shape_op or shape_op.type != "Shape":
            return False
        return set_shape_from_input(shape_op.inputs[0], op.outputs[0])

    if op.type == "Gather":
        # uses the following link to know how to infer shape of output
        # https://www.tensorflow.org/api_docs/python/tf/gather
        shape_params = get_tf_tensor_shape(op.inputs[0])
        shape_indices = get_tf_tensor_shape(op.inputs[1])
        # gather can only have 2 inputs
        # https://www.tensorflow.org/api_docs/cc/class/tensorflow/ops/gather.html
        if len(op.inputs) == 3:
            axis_op = op.inputs[2].op
            if not utils.is_tf_const_op(axis_op):
                return False
            axis = get_tf_const_value(axis_op)
        else:
            axis = 0
        # tf.gather accepts a negative axis counted from the end of params;
        # normalize it so the slicing below picks the right dimensions.
        if axis < 0:
            axis += len(shape_params)

        shape = shape_params[:axis] + shape_indices + shape_params[axis + 1:]
        op.outputs[0].set_shape(shape)
        return True

    if op.type in ["All", "Any", "Max", "Min"]:
        axis_op = op.inputs[1].op
        if not utils.is_tf_const_op(axis_op):
            return False
        axis = get_tf_const_value(axis_op)
        if not isinstance(axis, list):
            axis = [axis]
        keep_dims = op.get_attr("keep_dims")
        shape = get_tf_tensor_shape(op.inputs[0])
        rank = len(shape)
        # Normalize negative reduction axes against the input rank.
        reduced_axes = {a + rank if a < 0 else a for a in axis}

        new_shape = []
        for i, dim in enumerate(shape):
            if i in reduced_axes:
                # reduced dims are dropped, or kept as size 1 with keep_dims
                if keep_dims:
                    new_shape.append(1)
            else:
                new_shape.append(dim)

        op.outputs[0].set_shape(new_shape)
        logger.debug("set %s op [%s] with new shape %s", op.type,
                     op.outputs[0].name, new_shape)
        return True

    if op.type == "ExpandDims":
        # https://www.tensorflow.org/api_docs/python/tf/expand_dims
        input_shape = get_tf_tensor_shape(op.inputs[0])
        dim_op = op.inputs[1].op
        if input_shape is None or not utils.is_tf_const_op(dim_op):
            return False

        dim = get_tf_const_value(dim_op)
        # the output has one more dim, so negative dims wrap around rank + 1
        if dim < 0:
            dim = dim + len(input_shape) + 1

        new_shape = input_shape[:dim] + [1] + input_shape[dim:]
        op.outputs[0].set_shape(new_shape)
        logger.debug("set [%s] with new shape %s", op.outputs[0].name,
                     new_shape)
        return True

    if op.type == "Unpack":
        input_shape = get_tf_tensor_shape(op.inputs[0])
        if input_shape is None:
            return False

        axis = op.get_attr("axis")
        axis = axis if axis >= 0 else axis + len(input_shape)
        # the link below says that the rank of output is "rank(input) -1",
        # from this statement "num" must equal to input_shape[axis], and if not tf will throw a runtime error
        # https://www.tensorflow.org/api_docs/python/tf/unstack
        new_shape = input_shape[:axis] + input_shape[axis + 1:]
        for output in op.outputs:
            output.set_shape(new_shape)
            logger.debug("set %s op [%s] with new shape %s", op.type,
                         output.name, new_shape)
        return True

    if op.type in ["Minimum", "Maximum"]:
        # ops that are elementwise and support broadcasting
        input_shapes = [get_tf_tensor_shape(inp) for inp in op.inputs]
        new_shape = broadcast_shape_inference(*input_shapes)
        op.outputs[0].set_shape(new_shape)
        return True

    return False