def _compare_shape_for_op(self, op1, op2):
     """Align outputs of op2 to op1."""
     for out1, out2 in zip(op1.outputs, op2.outputs):
         expected_shape = get_tf_tensor_shape(out1)
         if out1 is not None:
             actual_shape = get_tf_tensor_shape(out2)
             self.assertTrue(
                 utils.are_shapes_compatible(expected_shape, actual_shape))
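The assertion above relies on utils.are_shapes_compatible, which is not shown in this listing. A minimal sketch of that kind of check, assuming None/-1 mark unknown dimensions (illustrative only, not the actual helper):

def shapes_compatible_sketch(expected, actual):
    # An unknown shape is compatible with anything.
    if expected is None or actual is None:
        return True
    # Ranks must match; individual dims match if equal or if either side is unknown.
    if len(expected) != len(actual):
        return False
    return all(e in (None, -1) or a in (None, -1) or e == a
               for e, a in zip(expected, actual))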
Example #2
def set_shape_from_inputs_broadcast(input_tensors, output_tensor):
    s1 = get_tf_tensor_shape(input_tensors[0])
    s2 = get_tf_tensor_shape(input_tensors[1])
    new_shape = broadcast_shape_inference(s1, s2)
    if new_shape is not None:
        output_tensor.set_shape(new_shape)
        logger.debug("set [%s] with new shape %s", output_tensor.name, new_shape)
        return True
    return False
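broadcast_shape_inference itself is defined elsewhere in the module. A minimal sketch of NumPy-style shape broadcasting, with unknown dims mapped to -1, purely illustrative:

def broadcast_shapes_sketch(s1, s2):
    if s1 is None or s2 is None:
        return None
    result = []
    # Left-pad the shorter shape with 1s, then broadcast dim by dim.
    for d1, d2 in zip([1] * (len(s2) - len(s1)) + list(s1),
                      [1] * (len(s1) - len(s2)) + list(s2)):
        if d1 in (None, -1) or d2 in (None, -1):
            result.append(-1)      # unknown stays unknown
        elif d1 == 1:
            result.append(d2)      # broadcast the size-1 dim
        elif d2 == 1 or d1 == d2:
            result.append(d1)
        else:
            return None            # incompatible shapes
    return result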
Example #3
def infer_input_shapes(op):
    if op.type in ["Select", "SelectV2"]:
        shape_t = get_tf_tensor_shape(op.inputs[1])
        shape_e = get_tf_tensor_shape(op.inputs[2])
        # copy the known shape when only one of t/e has a shape; no update needed when both have shapes
        if shape_t is None or shape_e is None:
            new_shape = shape_t or shape_e
            if new_shape is not None:
                op.inputs[1].set_shape(new_shape)
                op.inputs[2].set_shape(new_shape)
                logger.debug("set [%s, %s] with new shape %s", op.inputs[1].name, op.inputs[2].name, new_shape)
                return True
    return False
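A hypothetical usage sketch for the Select case; the graph construction below is an assumption for illustration (TF 2.x graph mode via Graph.as_default), not part of the original listing:

import tensorflow as tf

g = tf.Graph()
with g.as_default():
    cond = tf.compat.v1.placeholder(tf.bool, shape=[2, 3], name="cond")
    t = tf.compat.v1.placeholder(tf.float32, shape=None, name="t")  # shape unknown
    e = tf.compat.v1.placeholder(tf.float32, shape=[2, 3], name="e")
    select = tf.where(cond, t, e).op                                 # a SelectV2 node
infer_input_shapes(select)  # expected to copy [2, 3] from e onto t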
Example #4
def set_shape_from_input(input_tensor, output_tensor):
    new_shape = get_tf_tensor_shape(input_tensor)
    if new_shape is not None:
        output_tensor.set_shape(new_shape)
        logger.debug("set [%s] with new shape %s", output_tensor.name, new_shape)
        return True
    return False
def check_shape_for_tf_graph(tf_graph):
    """
    Check whether TF graph misses any shape,
    and return all ops with None shape outputs for TF graph.
    """
    op_outputs_mapping_none_shape = defaultdict(list)
    for op in tf_graph.get_operations():
        for out in op.outputs:
            if get_tf_tensor_shape(out) is None:
                op_outputs_mapping_none_shape[op.name].append(out.name)
    return op_outputs_mapping_none_shape
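A hypothetical usage sketch (the graph and names are assumptions for illustration): build a small graph with a shapeless placeholder and list the ops whose outputs have no static shape:

import tensorflow as tf

g = tf.Graph()
with g.as_default():
    x = tf.compat.v1.placeholder(tf.float32, shape=None, name="x")  # unknown rank
    y = tf.identity(x, name="y")
missing = check_shape_for_tf_graph(g)
for op_name, out_names in missing.items():
    print(op_name, "->", out_names)   # e.g. "x -> ['x:0']" and "y -> ['y:0']"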
def get_output_shapes(node_def, input_dtypes, input_shapes, inp_consts):
    """Returns a list of the output shapes of an op. input_dtypes should be tf dtypes."""
    from tf2onnx.tf_loader import tf_session, tf_placeholder  # pylint: disable=import-outside-toplevel

    if node_def.op in ["Prelu", "Enter"]:
        return [input_shapes[0]]

    if node_def.op == "Merge":
        # Find the first non-None shape (if it exists) and return it
        non_none = ([t for t in input_shapes if t is not None] + [None])[0]
        # The second output of merge is a scalar int indicating which input was selected
        return [non_none, []]

    if node_def.op == "Placeholder":
        shape = None
        if 'shape' in node_def.attr:
            shape = [d.size for d in node_def.attr['shape'].shape.dim]
            shape = [None if d == -1 else d for d in shape]
            if len(shape) == 0:
                # According to TF docs, "If the shape has 0 dimensions, the shape is unconstrained."
                shape = None
        return [shape]

    del node_def.input[:]
    node_def.name = "node"
    if "_class" in node_def.attr:
        # Remove colocation information (list of nodes tf wants computed on same device)
        del node_def.attr["_class"]

    g = tf.Graph()
    with g.as_default():
        for i, (dtype, shape,
                const) in enumerate(zip(input_dtypes, input_shapes,
                                        inp_consts)):
            inp = "input" + str(i)
            if const is None:
                if shape is not None and -1 in shape:
                    shape = [d if d != -1 else None for d in shape]
                tf_placeholder(dtype, name=inp, shape=shape)
            else:
                tf.constant(const, dtype, name=inp)
            node_def.input.append(inp)
        mini_graph_def = g.as_graph_def()
        mini_graph_def.node.append(node_def)
    g2 = tf.Graph()
    with g2.as_default():
        with tf_session() as sess:
            tf.import_graph_def(mini_graph_def, name='')
            node = sess.graph.get_operation_by_name("node")
            outputs_shapes = [
                tf_utils.get_tf_tensor_shape(out) for out in node.outputs
            ]
            return outputs_shapes
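A hypothetical usage sketch for get_output_shapes; the hand-built NodeDef below is an assumption for illustration, not part of the original listing:

import tensorflow as tf
from tensorflow.core.framework import node_def_pb2

relu_def = node_def_pb2.NodeDef()
relu_def.op = "Relu"
relu_def.name = "my_relu"                       # the helper renames it to "node"
relu_def.attr["T"].type = tf.float32.as_datatype_enum
# One float input with a partially known shape and no constant value.
shapes = get_output_shapes(relu_def, [tf.float32], [[None, 3]], [None])
print(shapes)   # expected: one output shape matching the input, e.g. [[None, 3]]
                # (exact unknown-dim encoding depends on get_tf_tensor_shape)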
def check_shape_for_tf_graph(tf_graph):
    """
    Check whether TF graph misses any shape,
    and return all ops with None shape outputs for TF graph.
    """
    skip_list = {'FusedBatchNormV3': 5}
    op_outputs_mapping_none_shape = defaultdict(list)
    for op in tf_graph.get_operations():
        for i, out in enumerate(op.outputs):
            if op.type in skip_list:
                if skip_list[op.type] == i:
                    continue
            if get_tf_tensor_shape(out) is None:
                op_outputs_mapping_none_shape[op.name].append(out.name)
    return op_outputs_mapping_none_shape
def infer_shape_for_op(op):
    has_unknown_output_shape = any(
        get_tf_tensor_shape(out) is None for out in op.outputs)

    if not has_unknown_output_shape:
        return False

    if op.type == "Placeholder":
        # if placeholder shape is not found, try to get it from "shape" attribute.
        attr_shape = get_tf_shape_attr(op)
        if attr_shape is not None:
            new_shape = list(attr_shape)
            op.outputs[0].set_shape(new_shape)
            logger.debug("set placeholder op [%s] with new shape %s",
                         op.outputs[0].name, new_shape)
            return True
        logger.warning(
            "Shape of placeholder %s is unknown, treating it as a scalar",
            op.name)
        op.outputs[0].set_shape([])
        return True

    if op.type == "Merge":
        s1 = get_tf_tensor_shape(op.inputs[0])
        s2 = get_tf_tensor_shape(op.inputs[1])
        new_shape = None
        if s1 is None and s2 is None:
            return False
        if s1 is None and s2 is not None:
            new_shape = s2
        if s1 is not None and s2 is None:
            new_shape = s1

        if new_shape is not None:
            op.inputs[0].set_shape(new_shape)
            op.inputs[1].set_shape(new_shape)
            op.outputs[0].set_shape(new_shape)
            logger.debug("set [%s] with new shape %s", op.outputs[0].name,
                         new_shape)
            return True

        # both input shapes are known
        if s1 != s2:
            if len(s1) != len(s2):
                logger.warning(
                    "Shapes of Merge %s have different ranks: %s, %s", op.name,
                    len(s1), len(s2))
                return False

            logger.debug(
                "Inputs of Merge %s have different shapes: %s, %s, but the same rank",
                op.name, s1, s2)
            new_shape = _merge_shapes_for_tf(s1, s2)
            op.outputs[0].set_shape(new_shape)
            logger.debug("set [%s] with new shape %s", op.outputs[0].name,
                         new_shape)
        else:
            new_shape = s1
            op.outputs[0].set_shape(new_shape)
            logger.debug("set [%s] with new shape %s", op.outputs[0].name,
                         new_shape)

        return True

    if op.type == "Switch":
        new_shape = get_tf_tensor_shape(op.inputs[0])
        if new_shape is not None:
            op.outputs[0].set_shape(new_shape)
            op.outputs[1].set_shape(new_shape)
            logger.debug("set [%s] with new shape %s", op.outputs[0].name,
                         new_shape)
            logger.debug("set [%s] with new shape %s", op.outputs[1].name,
                         new_shape)
            return True
        return False

    if op.type == "Enter":
        new_shape = get_tf_tensor_shape(op.inputs[0])
        if new_shape is not None:
            op.outputs[0].set_shape(new_shape)
            logger.debug("set [%s] with new shape %s", op.outputs[0].name,
                         new_shape)
            return True
        return False

    if op.type == "TensorArrayGatherV3":
        # TensorArrayGatherV3 outputs all elements of the TensorArray concatenated
        # along a new leading axis (new dim 0), so the element shape must be found first.
        # TensorArrayWriteV3 writes elements into the TensorArray, so the element shape
        # can be read from the matching write op. The process: find the TensorArrayWrite,
        # take the shape of the value it writes, and prepend one unknown dim to get the
        # shape of the TensorArrayGather output.

        handle_op = op.inputs[0].op
        if handle_op.type != "TensorArrayV3":
            return False

        # find TensorArrayWrite
        tensor_array_write_op = _find_tensorarray_write(handle_op)
        if not tensor_array_write_op:
            return False
        # get TensorArray shape from input tensor of the found TensorArrayWrite op
        shape = get_tf_tensor_shape(tensor_array_write_op.inputs[2])
        # update TensorArray's shape info
        if shape is not None:
            new_shape = [None] + shape
            op.outputs[0].set_shape(new_shape)
            logger.debug("set [%s] with new shape %s", op.outputs[0].name,
                         new_shape)
            return True
        return False

    if op.type == "TensorArrayReadV3":
        # TensorArrayReadV3 reads one element from the TensorArray into the output value.
        # The element shape can be taken from the matching TensorArrayScatterV3:
        # find the scatter op, take the shape of the value it scatters,
        # and drop its leading dim to get the element shape.
        flow_in_op = op.inputs[2].op
        if flow_in_op.type != "Enter":
            return False

        scatter_op = flow_in_op.inputs[0].op
        if scatter_op.type != "TensorArrayScatterV3":
            return False

        value_shape_before_scatter = get_tf_tensor_shape(scatter_op.inputs[2])
        if value_shape_before_scatter is None:
            return False

        # drop the leading dim to get the element shape; the slice is never None here
        new_shape = value_shape_before_scatter[1:]
        op.outputs[0].set_shape(new_shape)
        logger.debug("set [%s] with new shape %s", op.outputs[0].name,
                     new_shape)
        return True

    return False
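The Merge branch above calls _merge_shapes_for_tf, which is defined elsewhere in the module. A minimal sketch of a dim-wise merge of two same-rank shapes, illustrative only (not the actual helper):

def merge_shapes_sketch(s1, s2):
    # Assumes the ranks already match (the caller checks this).
    merged = []
    for d1, d2 in zip(s1, s2):
        if d1 in (None, -1):
            merged.append(d2)              # only s2 knows this dim
        elif d2 in (None, -1) or d1 == d2:
            merged.append(d1)              # only s1 knows it, or both agree
        else:
            merged.append(-1)              # conflicting known dims: treat as unknown
    return merged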
def infer_output_shapes_with_partial_inputs(op):
    # Output shape of a concat op: only the size of the concatenated dim changes,
    # so knowing at least one (not necessarily all) input shapes is enough to infer it.
    if utils.is_tf_concat_op(op):
        data_inputs = op.inputs[:-1]
        input_shapes = [get_tf_tensor_shape(inp) for inp in data_inputs]
        input_shapes = [shape for shape in input_shapes if shape is not None]
        if not input_shapes:
            logger.debug(
                "all input shapes of concat op %s are None, can't infer its output shape",
                op.name)
            return False

        new_shape = input_shapes[0]
        axis_op = op.inputs[-1]
        rank = len(new_shape)
        if not utils.is_tf_const_op(axis_op):
            op.outputs[0].set_shape([-1] * rank)
            return True

        axis = get_tf_const_value(axis_op)
        axis = axis if axis >= 0 else axis + rank
        new_shape[axis] = -1
        if len(input_shapes) == len(data_inputs):  # all input shapes are known
            concat_dim_vals = list(np.array(input_shapes)[:, axis])
            # the concat dim can only be summed when every input's concat dim is known
            if concat_dim_vals.count(-1) == 0:
                new_shape[axis] = sum(concat_dim_vals)

        op.outputs[0].set_shape(new_shape)
        logger.debug("set Concat op [%s] with new shape %s",
                     op.outputs[0].name, new_shape)
        return True

    if op.type in ["Select", "SelectV2"]:
        new_shape = get_tf_tensor_shape(op.inputs[1])
        if new_shape is None:
            new_shape = get_tf_tensor_shape(op.inputs[2])
        if new_shape is not None:
            op.outputs[0].set_shape(new_shape)
            op.inputs[1].set_shape(new_shape)
            op.inputs[2].set_shape(new_shape)
            logger.debug("set [%s] with new shape %s", op.outputs[0].name,
                         new_shape)
            return True
        return False

    if op.type == "Pack":
        axis = op.get_attr("axis")
        input_shape = None
        for i in op.inputs:
            s = get_tf_tensor_shape(i)
            if s is not None:
                input_shape = s
                break
        if input_shape is None:
            return False
        if axis < 0:
            axis += len(input_shape)
        for i in op.inputs:
            if get_tf_tensor_shape(i) is None:
                i.set_shape(input_shape)
                logger.debug("set [%s] with new shape %s", i.name, input_shape)
        new_shape = input_shape[:axis] + [len(op.inputs)] + input_shape[axis:]
        op.outputs[0].set_shape(new_shape)
        logger.debug("set Pack op [%s] with new shape %s", op.outputs[0].name,
                     new_shape)
        return True

    if op.type == "Pow":
        # https://www.tensorflow.org/api_docs/cc/class/tensorflow/ops/pow
        new_shape = get_tf_tensor_shape(op.inputs[0])
        if new_shape is None:
            new_shape = get_tf_tensor_shape(op.inputs[1])
        if new_shape is not None:
            op.outputs[0].set_shape(new_shape)
            logger.debug("set [%s] with new shape %s", op.outputs[0].name,
                         new_shape)
            return True
        return False

    # op not handled here: return None so the caller falls back to full-input inference
    return None
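A quick sanity check of the Pack shape arithmetic above, with made-up numbers: stacking four [2, 3] tensors along axis 1 inserts a new dim of size 4:

input_shape, axis, num_inputs = [2, 3], 1, 4
new_shape = input_shape[:axis] + [num_inputs] + input_shape[axis:]
assert new_shape == [2, 4, 3]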
def infer_shape_for_op_legacy(op):
    # invoke tf shape inference first
    infer_shape_for_op(op)

    has_unknown_input_shape = any(
        get_tf_tensor_shape(inp) is None for inp in op.inputs)
    has_unknown_output_shape = any(
        get_tf_tensor_shape(out) is None for out in op.outputs)

    # an input shape may be inferred from op output or other input shapes
    # try to infer it first
    if has_unknown_input_shape:
        if infer_input_shapes(op):
            return True

    if not has_unknown_output_shape:
        return False

    # for these ops, output shapes can be inferred without all input shapes being known
    ret = infer_output_shapes_with_partial_inputs(op)
    if ret is not None:
        return ret

    # for the remaining ops, all input shapes must be known before output shapes can be inferred
    are_all_input_shape_ready = True
    no_shape = []
    for i in op.inputs:
        if get_tf_tensor_shape(i) is None:
            are_all_input_shape_ready = False
            no_shape.append(i.name)

    if not are_all_input_shape_ready:
        logger.debug(
            "op %s has inputs without a specified shape: %s",
            op.name, no_shape)
        return False

    if op.type in direct_ops:
        return set_shape_from_input(op.inputs[0], op.outputs[0])

    if op.type in broadcast_ops:
        return set_shape_from_inputs_broadcast(op.inputs, op.outputs[0])

    if op.type == "RandomUniform":
        shape_op = op.inputs[0].op
        if not shape_op or shape_op.type != "Shape":
            return False
        return set_shape_from_input(shape_op.inputs[0], op.outputs[0])

    if op.type == "Gather":
        # see the following link for how the output shape is inferred
        # https://www.tensorflow.org/api_docs/python/tf/gather
        shape_params = get_tf_tensor_shape(op.inputs[0])
        shape_indices = get_tf_tensor_shape(op.inputs[1])
        # Gather itself takes 2 inputs (params, indices); a third input, if present, is the axis
        # https://www.tensorflow.org/api_docs/cc/class/tensorflow/ops/gather.html
        if len(op.inputs) == 3:
            axis_op = op.inputs[2].op
            if not utils.is_tf_const_op(axis_op):
                return False
            axis = get_tf_const_value(axis_op)
        else:
            axis = 0

        shape = shape_params[:axis] + shape_indices + shape_params[axis + 1:]
        op.outputs[0].set_shape(shape)
        return True

    if op.type in ["All", "Any", "Max", "Min"]:
        axis_op = op.inputs[1].op
        if not utils.is_tf_const_op(axis_op):
            return False
        axis = get_tf_const_value(axis_op)
        if not isinstance(axis, list):
            axis = [axis]
        keep_dims = op.get_attr("keep_dims")
        shape = get_tf_tensor_shape(op.inputs[0])
        for i, _ in enumerate(axis):
            if axis[i] < 0:
                axis[i] += len(shape)

        new_shape = []
        for i, _ in enumerate(shape):
            if i in axis:
                if keep_dims:
                    new_shape.append(1)
            else:
                new_shape.append(shape[i])

        op.outputs[0].set_shape(new_shape)
        logger.debug("set %s op [%s] with new shape %s", op.type,
                     op.outputs[0].name, new_shape)
        return True

    if op.type == "ExpandDims":
        # https://www.tensorflow.org/api_docs/python/tf/expand_dims
        input_shape = get_tf_tensor_shape(op.inputs[0])
        dim_op = op.inputs[1].op
        if input_shape is None or not utils.is_tf_const_op(dim_op):
            return False

        dim = get_tf_const_value(dim_op)
        if dim < 0:
            dim = dim + len(input_shape) + 1

        new_shape = input_shape[:dim] + [1] + input_shape[dim:]
        op.outputs[0].set_shape(new_shape)
        logger.debug("set [%s] with new shape %s", op.outputs[0].name,
                     new_shape)
        return True

    if op.type == "Unpack":
        input_shape = get_tf_tensor_shape(op.inputs[0])
        if input_shape is None:
            return False

        axis = op.get_attr("axis")
        axis = axis if axis >= 0 else axis + len(input_shape)
        # Per the link below, the output rank is rank(input) - 1; "num" must equal
        # input_shape[axis], otherwise TF raises a runtime error.
        # https://www.tensorflow.org/api_docs/python/tf/unstack
        new_shape = input_shape[:axis] + input_shape[axis + 1:]
        for output in op.outputs:
            output.set_shape(new_shape)
            logger.debug("set %s op [%s] with new shape %s", op.type,
                         output.name, new_shape)
        return True

    if op.type in ["Minimum", "Maximum"]:
        # elementwise ops that support broadcasting
        input_shapes = [get_tf_tensor_shape(inp) for inp in op.inputs]
        new_shape = broadcast_shape_inference(*input_shapes)
        op.outputs[0].set_shape(new_shape)
        return True

    return False
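direct_ops and broadcast_ops are defined elsewhere in the module; the lists below are illustrative guesses only, meant to show how the two fallbacks differ:

# Illustrative examples only -- not the module's actual lists.
direct_ops = ["Identity", "Relu", "Cast"]         # output shape assumed equal to the first input's
broadcast_ops = ["Add", "Sub", "Mul", "RealDiv"]  # elementwise ops with NumPy-style broadcasting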