Example #1
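# Imports assumed by this snippet (the standard tensorflow.core.framework
# proto packages; a sketch, not part of the original excerpt):
from tensorflow.core.framework import cost_graph_pb2
from tensorflow.core.framework import tensor_shape_pb2
from tensorflow.core.framework import types_pb2
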
    def compute_cost_graph(self, devices=None):
        """Computes a CostGraphDef protobuf based on this graph.

        Defined in tensorflow/core/framework/cost_graph.proto.

        Args:
          devices: optional [string], the names of devices to consider. If
            specified, any tensor on a device not listed is given a size of zero.
            Any device-less tensor (e.g. a Mesh TensorFlow tensor) is not affected.

        Returns:
          a CostGraphDef protobuf with a Node for every operation in the graph,
          each of which is populated with size/dtype information for its inputs
          and outputs (which match the input/output order of the operation).
        """
        cost_graph_def = cost_graph_pb2.CostGraphDef()

        for i, operation_name in enumerate(self.get_all_operation_names()):
            node = cost_graph_def.node.add(
                name=operation_name,
                device=self.get_operation_device(operation_name),
                id=i)
            for input_name in self.get_operation_input_names(operation_name):
                id1, id2 = self._tensor_name_to_ids[input_name]
                node.input_info.add(preceding_node=id1, preceding_port=id2)

            for output_name in self.get_operation_output_names(operation_name):
                tensor_device = self.get_tensor_device(output_name)
                # devices = [] is not the same as None, and tensor_device = '' is also
                # not the same as None.
                if devices is None or tensor_device is None or tensor_device in devices:
                    node.output_info.add(
                        size=self.get_tensor_size(output_name),
                        alias_input_port=-1,
                        dtype=self.get_tensor_dtype(
                            output_name).as_datatype_enum,
                        shape=self.get_tensor_shape(output_name).as_proto(),
                    )
                else:
                    node.output_info.add(
                        size=0,
                        alias_input_port=-1,
                        dtype=self.get_tensor_dtype(
                            output_name).as_datatype_enum,
                    )

                # NOTE(joshuawang): Unfortunately, the CostGraphDef protobuf marks
                # operations as final, not tensors. As a result, we have to declare
                # any operation that outputs a final tensor as final, which may
                # expand the final set of tensors to keep in memory. This issue also
                # arises in the scheduler code we will interface with.
                if self.is_tensor_final(output_name):
                    node.is_final = True

        return cost_graph_def
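
    # Usage sketch (hedged; `graph` is a hypothetical object implementing this
    # interface and is not part of the original source):
    #
    #   cost_graph = graph.compute_cost_graph(devices=["/device:CPU:0"])
    #   for node in cost_graph.node:
    #     print(node.name, node.device, [o.size for o in node.output_info])
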
    def StripCostGraphDef(self, cost_graph, to_strip):
        """Removes fields from a CostGraphDef protobuf.

        Helper method to reduce repetition when initializing CostGraphDefs.

        Args:
          cost_graph: a CostGraphDef to strip.
          to_strip: a string, either "SIZES" or "DEVICES".

        Returns:
          a new CostGraphDef with either size information or device information
          stripped, as appropriate.
        """
        new_cost_graph = cost_graph_pb2.CostGraphDef()
        new_cost_graph.CopyFrom(cost_graph)
        for node in new_cost_graph.node:
            if to_strip == "SIZES":
                for output_info in node.output_info:
                    output_info.size = 0
                    output_info.ClearField("shape")
            if to_strip == "DEVICES":
                node.ClearField("device")
        return new_cost_graph
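
    # Usage sketch (hedged; `cost_graph` stands for any CostGraphDef, e.g. one
    # produced by compute_cost_graph above):
    #
    #   sizeless = self.StripCostGraphDef(cost_graph, "SIZES")      # sizes/shapes zeroed
    #   deviceless = self.StripCostGraphDef(cost_graph, "DEVICES")  # device names cleared
    #   # The input is left untouched; a stripped copy is returned.
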
    def setUp(self):
        super(GraphInterfaceTest, self).setUp()
        self._cost_graph = cost_graph_pb2.CostGraphDef(node=[
            cost_graph_pb2.CostGraphDef.Node(
                name="X",
                device="/device:CPU:0",
                id=0,
                output_info=[
                    cost_graph_pb2.CostGraphDef.Node.OutputInfo(
                        size=48,
                        alias_input_port=-1,
                        dtype=types_pb2.DT_INT32,
                        shape=tensor_shape_pb2.TensorShapeProto(dim=[
                            tensor_shape_pb2.TensorShapeProto.Dim(size=3),
                            tensor_shape_pb2.TensorShapeProto.Dim(size=4),
                        ])),
                ],
            ),
            cost_graph_pb2.CostGraphDef.Node(
                name="Y",
                device="/device:CPU:0",
                id=1,
                output_info=[
                    cost_graph_pb2.CostGraphDef.Node.OutputInfo(
                        size=80,
                        alias_input_port=-1,
                        dtype=types_pb2.DT_INT32,
                        shape=tensor_shape_pb2.TensorShapeProto(dim=[
                            tensor_shape_pb2.TensorShapeProto.Dim(size=4),
                            tensor_shape_pb2.TensorShapeProto.Dim(size=5),
                        ])),
                ],
            ),
            cost_graph_pb2.CostGraphDef.Node(
                name="Z1",
                device="/device:CPU:0",
                id=2,
                input_info=[
                    cost_graph_pb2.CostGraphDef.Node.InputInfo(
                        preceding_node=0,
                        preceding_port=0,
                    ),
                    cost_graph_pb2.CostGraphDef.Node.InputInfo(
                        preceding_node=1,
                        preceding_port=0,
                    ),
                ],
                output_info=[
                    cost_graph_pb2.CostGraphDef.Node.OutputInfo(
                        size=60,
                        alias_input_port=-1,
                        dtype=types_pb2.DT_INT32,
                        shape=tensor_shape_pb2.TensorShapeProto(dim=[
                            tensor_shape_pb2.TensorShapeProto.Dim(size=3),
                            tensor_shape_pb2.TensorShapeProto.Dim(size=5),
                        ])),
                ],
                is_final=True,
            ),
            cost_graph_pb2.CostGraphDef.Node(
                name="Z2",
                device="/device:CPU:0",
                id=3,
                input_info=[
                    cost_graph_pb2.CostGraphDef.Node.InputInfo(
                        preceding_node=0,
                        preceding_port=0,
                    ),
                    cost_graph_pb2.CostGraphDef.Node.InputInfo(
                        preceding_node=1,
                        preceding_port=0,
                    ),
                ],
                output_info=[
                    cost_graph_pb2.CostGraphDef.Node.OutputInfo(
                        size=60,
                        alias_input_port=-1,
                        dtype=types_pb2.DT_INT32,
                        shape=tensor_shape_pb2.TensorShapeProto(dim=[
                            tensor_shape_pb2.TensorShapeProto.Dim(size=3),
                            tensor_shape_pb2.TensorShapeProto.Dim(size=5),
                        ])),
                ],
            ),
        ])
        self._sizeless_cost_graph = self.StripCostGraphDef(
            self._cost_graph, "SIZES")
        self._deviceless_cost_graph = self.StripCostGraphDef(
            self._cost_graph, "DEVICES")

        self._cost_graph_string = self._cost_graph.SerializeToString()
        self._sizeless_cost_graph_string = (
            self._sizeless_cost_graph.SerializeToString())
        self._deviceless_cost_graph_string = (
            self._deviceless_cost_graph.SerializeToString())
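
    def testStrippedFixtures(self):
        # Hedged sanity-check sketch (not in the original source): verify that
        # the stripped fixtures built in setUp() drop the intended fields.
        for node in self._sizeless_cost_graph.node:
            for output_info in node.output_info:
                self.assertEqual(0, output_info.size)
                self.assertFalse(output_info.HasField("shape"))
        for node in self._deviceless_cost_graph.node:
            self.assertEqual("", node.device)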