Example #1
0
 def _compare_shape_for_op(self, op1, op2):
     """Assert that each output shape of ``op2`` is compatible with ``op1``'s.

     Outputs are paired positionally via ``zip``. ``op1`` supplies the
     expected shape and ``op2`` the actual one. Outputs whose expected shape
     is unknown (``None``) are skipped because there is nothing to compare.
     """
     for out1, out2 in zip(op1.outputs, op2.outputs):
         expected_shape = utils.get_tf_tensor_shape(out1)
         # Guard on the shape, not the tensor: an element of ``op.outputs``
         # is never None, but its static shape may be unknown.
         if expected_shape is None:
             continue
         actual_shape = utils.get_tf_tensor_shape(out2)
         self.assertTrue(
             utils.are_shapes_compatible(expected_shape, actual_shape),
             "shape mismatch for %s: expected %s, got %s"
             % (out2.name, expected_shape, actual_shape))
def set_shape_from_inputs_broadcast(input_tensors, output_tensor):
    """Give ``output_tensor`` the broadcast shape of the first two inputs.

    Returns:
        True if ``broadcast_shape_inference`` produced a shape and it was
        applied to ``output_tensor``; False otherwise (nothing is changed).
    """
    lhs_shape = utils.get_tf_tensor_shape(input_tensors[0])
    rhs_shape = utils.get_tf_tensor_shape(input_tensors[1])
    broadcast_shape = broadcast_shape_inference(lhs_shape, rhs_shape)
    if broadcast_shape is None:
        return False
    output_tensor.set_shape(broadcast_shape)
    logger.debug("set [%s] with new shape %s", output_tensor.name,
                 broadcast_shape)
    return True
def infer_input_shapes(op):
    """Infer missing input shapes of ``op`` from its other inputs.

    Only ``Select`` is handled. When exactly one of its value inputs
    (index 1, ``t``; index 2, ``e``) has a known shape, that shape is set on
    both. Nothing is updated when both or neither shape is known.

    Returns:
        True if input shapes were set, False otherwise.
    """
    if op.type != "Select":
        return False

    shape_t = utils.get_tf_tensor_shape(op.inputs[1])
    shape_e = utils.get_tf_tensor_shape(op.inputs[2])
    # copy shape if t OR e does not have a shape, no update if t AND e both have shapes
    if shape_t is not None and shape_e is not None:
        return False

    # Select explicitly with ``is not None``: a scalar shape is the empty
    # list, which is falsy, so ``shape_t or shape_e`` would discard it.
    new_shape = shape_t if shape_t is not None else shape_e
    if new_shape is None:
        return False

    op.inputs[1].set_shape(new_shape)
    op.inputs[2].set_shape(new_shape)
    logger.debug("set [%s, %s] with new shape %s",
                 op.inputs[1].name, op.inputs[2].name, new_shape)
    return True
def set_shape_from_input(input_tensor, output_tensor):
    """Copy the static shape of ``input_tensor`` onto ``output_tensor``.

    Returns:
        True if the shape was copied; False (nothing changed) when the shape
        of ``input_tensor`` is unknown.
    """
    shape = utils.get_tf_tensor_shape(input_tensor)
    if shape is None:
        return False
    output_tensor.set_shape(shape)
    logger.debug("set [%s] with new shape %s", output_tensor.name, shape)
    return True
def check_shape_for_tf_graph(tf_graph):
    """Report the outputs of ``tf_graph`` whose static shape is unknown.

    Returns:
        A ``defaultdict(list)`` mapping op name to the names of that op's
        outputs for which ``utils.get_tf_tensor_shape`` returned None. Ops
        whose outputs all have a shape are absent from the mapping.
    """
    missing = defaultdict(list)
    for graph_op in tf_graph.get_operations():
        unknown = [tensor.name for tensor in graph_op.outputs
                   if utils.get_tf_tensor_shape(tensor) is None]
        if unknown:
            missing[graph_op.name].extend(unknown)
    return missing
def infer_shape_for_op(op):
    """Infer output shapes for Placeholder, control-flow and TensorArray ops.

    Handled op types: ``Placeholder``, ``Merge``, ``Switch``, ``Enter``,
    ``TensorArrayGatherV3`` and ``TensorArrayReadV3``. Nothing is done when
    every output of ``op`` already has a shape.

    Returns:
        True if a shape was set, False otherwise.
    """
    has_unknown_output_shape = any(
        utils.get_tf_tensor_shape(out) is None for out in op.outputs)
    if not has_unknown_output_shape:
        return False

    if op.type == "Placeholder":
        # if placeholder shape is not found, try to get it from "shape" attribute.
        attr_shape = utils.get_tf_shape_attr(op)
        if attr_shape is not None:
            new_shape = list(attr_shape)
            op.outputs[0].set_shape(new_shape)
            logger.debug("set placeholder op [%s] with new shape %s",
                         op.outputs[0].name, new_shape)
            return True
        logger.warning(
            "Shape of placeholder %s is unknown, treated it as a scalar",
            op.name)
        op.outputs[0].set_shape([])
        return True

    if op.type == "Merge":
        s1 = utils.get_tf_tensor_shape(op.inputs[0])
        s2 = utils.get_tf_tensor_shape(op.inputs[1])
        if s1 is None and s2 is None:
            return False

        if s1 is None or s2 is None:
            # Exactly one input shape is known: propagate it to the other
            # input as well as to the output.
            new_shape = s2 if s1 is None else s1
            op.inputs[0].set_shape(new_shape)
            op.inputs[1].set_shape(new_shape)
        elif s1 == s2:
            new_shape = s1
        else:
            # inputs' shapes both exist but differ
            if len(s1) != len(s2):
                logger.warning(
                    "Shapes of Merge %s have different ranks: %s, %s", op.name,
                    len(s1), len(s2))
                return False
            logger.debug(
                "Inputs of Merge %s have different shapes: %s, %s, but the same rank",
                op.name, s1, s2)
            new_shape = _merge_shapes_for_tf(s1, s2)

        op.outputs[0].set_shape(new_shape)
        logger.debug("set [%s] with new shape %s", op.outputs[0].name,
                     new_shape)
        return True

    if op.type == "Switch":
        # Both outputs of Switch receive the shape of the data input.
        new_shape = utils.get_tf_tensor_shape(op.inputs[0])
        if new_shape is None:
            return False
        op.outputs[0].set_shape(new_shape)
        op.outputs[1].set_shape(new_shape)
        logger.debug("set [%s] with new shape %s", op.outputs[0].name,
                     new_shape)
        logger.debug("set [%s] with new shape %s", op.outputs[1].name,
                     new_shape)
        return True

    if op.type == "Enter":
        # Enter forwards its input unchanged, so its output has the same shape.
        return set_shape_from_input(op.inputs[0], op.outputs[0])

    if op.type == "TensorArrayGatherV3":
        # TensorArrayGatherV3's output: all of the elem in the TensorArray,
        # concatenated along a new axis (the new dimension 0), so shape of TensorArray should be found first.
        # And TensorArrayWrite will write elem to TensorArray, so shape of TensorArray can be got from TensorArrayWrite
        # so the process is: first find TensorArrayWrite and then get TensorArray's shape,
        # and finally add one dim to the shape is shape of TensorArrayGather
        handle_op = op.inputs[0].op
        if handle_op.type != "TensorArrayV3":
            return False

        # find TensorArrayWrite
        tensor_array_write_op = _find_tensorarray_write(handle_op)
        if not tensor_array_write_op:
            return False
        # get TensorArray shape from input tensor of the found TensorArrayWrite op
        shape = utils.get_tf_tensor_shape(tensor_array_write_op.inputs[2])
        if shape is None:
            return False
        # The gathered dimension's length is unknown here, hence None.
        new_shape = [None] + shape
        op.outputs[0].set_shape(new_shape)
        logger.debug("set [%s] with new shape %s", op.outputs[0].name,
                     new_shape)
        return True

    if op.type == "TensorArrayReadV3":
        # TensorArrayRead reads an element from the TensorArray into output value.
        # The TensorArray's shape can be got from TensorArrayScatter.
        # So the process is: first find TensorArrayScatter's shape and then TensorArray's
        # and finally take its last n-1 dim.
        flow_in_op = op.inputs[2].op
        if flow_in_op.type != "Enter":
            return False

        scatter_op = flow_in_op.inputs[0].op
        if scatter_op.type != "TensorArrayScatterV3":
            return False

        value_shape_before_scatter = utils.get_tf_tensor_shape(
            scatter_op.inputs[2])
        if value_shape_before_scatter is None:
            return False

        # A slice of a list is never None, so no further check is needed.
        new_shape = value_shape_before_scatter[1:]
        op.outputs[0].set_shape(new_shape)
        logger.debug("set [%s] with new shape %s", op.outputs[0].name,
                     new_shape)
        return True

    return False
def infer_output_shapes_with_partial_inputs(op):
    """Infer output shapes of ops that need only some of their input shapes.

    Handles concat ops (as recognised by ``utils.is_tf_concat_op``),
    ``Select``, ``Pack`` and ``Pow``.

    Returns:
        True if an output shape was set. False if the op type is handled
        here but no shape could be inferred. None if the op type is not
        handled here, so the caller falls back to other rules.
    """
    # output shape of concat op: only the dim val of concatenated dim will be changed
    # so only partial(at least one) input shapes need to be known to infer output shape of concat op
    if utils.is_tf_concat_op(op):
        data_inputs = op.inputs[:-1]
        input_shapes = [utils.get_tf_tensor_shape(inp) for inp in data_inputs]
        input_shapes = [shape for shape in input_shapes if shape is not None]
        if not input_shapes:
            logger.debug(
                "all input shapes of concat op %s are None, can't infer its output shape",
                op.name)
            return False

        # Copy, because new_shape is modified below while input_shapes[0]
        # must keep its real value along the concat axis for the sum.
        new_shape = list(input_shapes[0])
        # The axis is the last input tensor; the const checks operate on the
        # op producing it, as in the other const-axis rules of this module.
        axis_op = op.inputs[-1].op
        rank = len(new_shape)
        if not utils.is_tf_const_op(axis_op):
            op.outputs[0].set_shape([-1] * rank)
            return True

        axis = utils.get_tf_const_value(axis_op)
        axis = axis if axis >= 0 else axis + rank
        new_shape[axis] = -1
        if len(input_shapes) == len(data_inputs):  # all input shapes are known
            concat_dim_vals = [shape[axis] for shape in input_shapes]
            # The concat dim can be summed only when every input's size along
            # it is known (unknown dims may appear as None or -1).
            if all(val is not None and val != -1 for val in concat_dim_vals):
                new_shape[axis] = sum(concat_dim_vals)

        op.outputs[0].set_shape(new_shape)
        logger.debug("set Concat op [%s] with new shape %s",
                     op.outputs[0].name, new_shape)
        return True

    if op.type == "Select":
        # The output and both value inputs share one shape; take whichever
        # value input shape is known.
        new_shape = utils.get_tf_tensor_shape(op.inputs[1])
        if new_shape is None:
            new_shape = utils.get_tf_tensor_shape(op.inputs[2])
        if new_shape is None:
            return False
        op.outputs[0].set_shape(new_shape)
        op.inputs[1].set_shape(new_shape)
        op.inputs[2].set_shape(new_shape)
        logger.debug("set [%s] with new shape %s", op.outputs[0].name,
                     new_shape)
        return True

    if op.type == "Pack":
        axis = op.get_attr("axis")
        input_shape = None
        for inp in op.inputs:
            shape = utils.get_tf_tensor_shape(inp)
            if shape is not None:
                input_shape = shape
                break
        if input_shape is None:
            return False
        if axis < 0:
            # The packed output has rank len(input_shape) + 1, so a negative
            # axis wraps around that rank (axis=-1 appends a new last dim).
            axis += len(input_shape) + 1
        for inp in op.inputs:
            # tf.stack requires identical input shapes; fill in unknown ones.
            # Test for None explicitly so a scalar shape [] counts as known.
            if utils.get_tf_tensor_shape(inp) is None:
                inp.set_shape(input_shape)
                logger.debug("set [%s] with new shape %s", inp.name, input_shape)
        new_shape = input_shape[:axis] + [len(op.inputs)] + input_shape[axis:]
        op.outputs[0].set_shape(new_shape)
        logger.debug("set Pack op [%s] with new shape %s", op.outputs[0].name,
                     new_shape)
        return True

    if op.type == "Pow":
        # https://www.tensorflow.org/api_docs/cc/class/tensorflow/ops/pow
        # NOTE(review): the first known input shape is used even when both
        # are known, so broadcasting between the inputs is not modelled.
        new_shape = utils.get_tf_tensor_shape(op.inputs[0])
        if new_shape is None:
            new_shape = utils.get_tf_tensor_shape(op.inputs[1])
        if new_shape is None:
            return False
        op.outputs[0].set_shape(new_shape)
        logger.debug("set [%s] with new shape %s", op.outputs[0].name,
                     new_shape)
        return True

    return None
def infer_shape_for_op_legacy(op):
    """Infer missing input/output shapes of ``op`` using hand-written rules.

    Order of attempts:

    1. ``infer_shape_for_op`` (Placeholder/control-flow/TensorArray rules).
       Its return value is ignored; shapes are re-read afterwards.
    2. ``infer_input_shapes`` when some input shape is missing.
    3. ``infer_output_shapes_with_partial_inputs`` for ops that need only
       some input shapes.
    4. When every input shape is known, op-specific rules: ``direct_ops``,
       ``broadcast_ops``, RandomUniform, Gather, All/Any/Max/Min,
       ExpandDims, Unpack and Minimum/Maximum.

    Returns:
        True if one of steps 2-4 set a shape, False otherwise. Shapes set by
        step 1 alone are not reflected in the return value.
    """
    # Step 1: its result is deliberately not used (see docstring).
    infer_shape_for_op(op)

    has_unknown_input_shape = any(
        utils.get_tf_tensor_shape(inp) is None for inp in op.inputs)
    has_unknown_output_shape = any(
        utils.get_tf_tensor_shape(out) is None for out in op.outputs)

    # an input shape may be inferred from op output or other input shapes
    # try to infer it first
    if has_unknown_input_shape and infer_input_shapes(op):
        return True

    if not has_unknown_output_shape:
        return False

    # for those ops, we don't expect all input shapes available to infer output shapes.
    ret = infer_output_shapes_with_partial_inputs(op)
    if ret is not None:
        return ret

    # for the remaining ops, all input shapes must be known to infer output shapes.
    no_shape = [inp.name for inp in op.inputs
                if utils.get_tf_tensor_shape(inp) is None]
    if no_shape:
        logger.debug(
            "op %s has inputs don't have shape specified, they are: %s",
            op.name, no_shape)
        return False

    if op.type in direct_ops:
        return set_shape_from_input(op.inputs[0], op.outputs[0])

    if op.type in broadcast_ops:
        return set_shape_from_inputs_broadcast(op.inputs, op.outputs[0])

    if op.type == "RandomUniform":
        # The output takes the shape of the tensor fed into the Shape op that
        # produces RandomUniform's shape input.
        shape_op = op.inputs[0].op
        if not shape_op or shape_op.type != "Shape":
            return False
        return set_shape_from_input(shape_op.inputs[0], op.outputs[0])

    if op.type == "Gather":
        # uses the following link to know how to infer shape of output
        # https://www.tensorflow.org/api_docs/python/tf/gather
        shape_params = utils.get_tf_tensor_shape(op.inputs[0])
        shape_indices = utils.get_tf_tensor_shape(op.inputs[1])
        # A third input, when present, carries the axis and must be constant;
        # otherwise the axis defaults to 0.
        if len(op.inputs) == 3:
            axis_op = op.inputs[2].op
            if not utils.is_tf_const_op(axis_op):
                return False
            axis = utils.get_tf_const_value(axis_op)
        else:
            axis = 0
        # Normalise a negative axis; slicing with it directly would splice
        # the indices shape into the wrong position.
        if axis < 0:
            axis += len(shape_params)

        shape = shape_params[:axis] + shape_indices + shape_params[axis + 1:]
        op.outputs[0].set_shape(shape)
        return True

    if op.type in ["All", "Any", "Max", "Min"]:
        axis_op = op.inputs[1].op
        if not utils.is_tf_const_op(axis_op):
            return False
        axis = utils.get_tf_const_value(axis_op)
        if not isinstance(axis, list):
            axis = [axis]
        keep_dims = op.get_attr("keep_dims")
        shape = utils.get_tf_tensor_shape(op.inputs[0])
        rank = len(shape)
        axis = [a + rank if a < 0 else a for a in axis]

        new_shape = []
        for i, dim in enumerate(shape):
            if i not in axis:
                new_shape.append(dim)
            elif keep_dims:
                # reduced dims are kept with size 1 when keep_dims is set
                new_shape.append(1)

        op.outputs[0].set_shape(new_shape)
        logger.debug("set %s op [%s] with new shape %s", op.type,
                     op.outputs[0].name, new_shape)
        return True

    if op.type == "ExpandDims":
        # https://www.tensorflow.org/api_docs/python/tf/expand_dims
        input_shape = utils.get_tf_tensor_shape(op.inputs[0])
        dim_op = op.inputs[1].op
        if input_shape is None or not utils.is_tf_const_op(dim_op):
            return False

        dim = utils.get_tf_const_value(dim_op)
        if dim < 0:
            # the output has one more dim than the input
            dim = dim + len(input_shape) + 1

        new_shape = input_shape[:dim] + [1] + input_shape[dim:]
        op.outputs[0].set_shape(new_shape)
        logger.debug("set [%s] with new shape %s", op.outputs[0].name,
                     new_shape)
        return True

    if op.type == "Unpack":
        input_shape = utils.get_tf_tensor_shape(op.inputs[0])
        if input_shape is None:
            return False

        axis = op.get_attr("axis")
        axis = axis if axis >= 0 else axis + len(input_shape)
        # the link below says that the rank of output is "rank(input) -1",
        # from this statement "num" must equal to input_shape[axis], and if not tf will throw a runtime error
        # https://www.tensorflow.org/api_docs/python/tf/unstack
        new_shape = input_shape[:axis] + input_shape[axis + 1:]
        for output in op.outputs:
            output.set_shape(new_shape)
            logger.debug("set %s op [%s] with new shape %s", op.type,
                         output.name, new_shape)
        return True

    if op.type in ["Minimum", "Maximum"]:
        # ops that are elementwise and support broadcasting; returns False
        # (instead of setting a None shape) when broadcasting fails.
        return set_shape_from_inputs_broadcast(op.inputs, op.outputs[0])

    return False
Example #9
0
def write_nodes(ops, filename):
    """Write a JSON description of ``ops`` to ``filename``.

    Each op becomes one record with these keys:

    - ``op_name``: the op's name.
    - ``output`` / ``inputs``: lists of ``{name, shape, dtype}`` per tensor.
      Tensor names have their ``:<index>`` suffix removed. An unknown first
      dimension is written as -1, and an unknown rank as ``null``.
    - ``operator_name``: the op type.
    - ``padding``, ``strides``, ``dilations``: taken from the op's attributes
      when present, otherwise the string ``"None"``.

    The file is overwritten with 4-space-indented JSON followed by a newline.
    """

    def _describe_tensor(tensor):
        # Build the name/shape/dtype record for one input or output tensor.
        shape = utils.get_tf_tensor_shape(tensor)
        # shape is None when the rank is unknown; only a known, non-empty
        # shape can have its leading unknown dimension replaced by -1.
        if shape and shape[0] is None:
            shape[0] = -1
        return {
            'name': tensor.name.split(':')[0],
            'shape': shape,
            'dtype': tensor.dtype.name,
        }

    records = []
    for node in ops:
        attrs = dict(node.node_def.attr.items())
        record = {
            'op_name': node.name,
            'output': [_describe_tensor(t) for t in node.outputs],
            'inputs': [_describe_tensor(t) for t in node.inputs],
            'operator_name': node.type,
        }
        # Attributes of interest; the string "None" marks a missing one.
        if 'padding' in attrs:
            record['padding'] = attrs['padding'].s.decode("utf-8")
        else:
            record['padding'] = "None"
        for key in ('strides', 'dilations'):
            if key in attrs:
                record[key] = [int(v) for v in attrs[key].list.i]
            else:
                record[key] = "None"
        records.append(record)

    # Serialise directly; building JSON from str() of Python objects breaks
    # on None values and on quotes inside strings.
    with open(filename, 'w') as writer:
        print(json.dumps(records, indent=4, sort_keys=False), file=writer)