예제 #1
0
 def test_unique_graph(self):
     """Exercise check_graphs and get_unique_graph on one and two graphs.

     Constants taken from a single graph must be accepted by both helpers,
     while mixing constants from two separate graphs must raise ValueError.
     """
     first_graph = ops_lib.Graph()
     with first_graph.as_default():
         one_a, two_a = constant_op.constant(1), constant_op.constant(2)

     second_graph = ops_lib.Graph()
     with second_graph.as_default():
         one_b, two_b = constant_op.constant(1), constant_op.constant(2)

     mixed = [one_a, two_a, one_b, two_b]

     # Tensors that share a graph: check_graphs returns None.
     same_graph_result = op_selector.check_graphs(one_a, two_a)
     self.assertIsNone(same_graph_result)

     # Tensors spread over both graphs: check_graphs must complain.
     self.assertRaises(ValueError, op_selector.check_graphs, *mixed)

     # Tensors that share a graph: that graph is the unique one.
     found_graph = op_selector.get_unique_graph([one_a, two_a])
     self.assertEqual(found_graph, first_graph)

     # Tensors spread over both graphs: there is no unique graph.
     self.assertRaises(ValueError, op_selector.get_unique_graph, mixed)
예제 #2
0
    def test_unique_graph_func_graph(self):
        """Check get_unique_graph on tensors from a Graph and a FuncGraph.

        The FuncGraph is created inside the outer graph's context and is
        given the outer graph's key, so one tensor from each is expected to
        resolve to a graph carrying that shared key.
        """
        outer_graph = ops_lib.Graph()
        with outer_graph.as_default():
            outer_const = constant_op.constant(1)
            inner_graph = func_graph.FuncGraph("inner")
            # Make both graphs report the same key before building in `inner`.
            inner_graph._graph_key = outer_graph._graph_key
            with inner_graph.as_default():
                inner_const = constant_op.constant(2)

        resolved = op_selector.get_unique_graph([outer_const, inner_const])
        resolved_key = resolved._graph_key
        self.assertEqual(resolved_key, inner_graph._graph_key)
예제 #3
0
def _ensure_servable(input_tensors, names_to_output_tensor_infos):
  """Verify that the signature's outputs are computable from its inputs.

  Walks backwards from the output tensors, stopping at the input tensors, and
  rejects the signature when the walk reaches an op that `_must_be_fed` flags
  and whose outputs are not all among the signature's inputs.

  Args:
    input_tensors: An iterable of `Tensor`s specified as the signature's inputs.
    names_to_output_tensor_infos: A mapping from output names to the respective
      `TensorInfo`s corresponding to the signature's output tensors.

  Raises:
    ValueError: If any of the signature's outputs depend on placeholders not
      provided as the signature's inputs.
  """
  flat_inputs = nest.flatten(input_tensors, expand_composites=True)

  # The output `TensorInfo`s are resolved against the graph of the inputs.
  signature_graph = op_selector.get_unique_graph(flat_inputs)

  resolved_outputs = []
  for tensor_info in names_to_output_tensor_infos.values():
    resolved_outputs.append(
        utils.get_tensor_from_tensor_info(tensor_info, graph=signature_graph))
  flat_outputs = nest.flatten(resolved_outputs, expand_composites=True)

  upstream_ops = op_selector.get_backward_walk_ops(
      flat_outputs, stop_at_ts=flat_inputs)

  # Membership is tested through an identity-based set of the inputs.
  fed_tensors = object_identity.ObjectIdentitySet(flat_inputs)
  for dependency_op in upstream_ops:
    if not _must_be_fed(dependency_op):
      continue
    if all(produced in fed_tensors for produced in dependency_op.outputs):
      # Every output of this op is supplied as a signature input.
      continue

    # Gather the pieces of the diagnostic only once a violation is found.
    input_tensor_names = [tensor.name for tensor in flat_inputs]
    output_tensor_keys = list(names_to_output_tensor_infos.keys())
    output_tensor_names = [tensor.name for tensor in flat_outputs]
    dependency_path = op_selector.show_path(
        dependency_op, flat_outputs, flat_inputs)
    raise ValueError(
        f'The signature\'s input tensors {input_tensor_names} are '
        f'insufficient to compute its output keys {output_tensor_keys} '
        f'(respectively, tensors {output_tensor_names}) because of the '
        f'dependency on `{dependency_op.name}` which is not given as '
        'a signature input, as illustrated by the following dependency path: '
        f'{dependency_path}')