Esempio n. 1
0
 def __init__(self,
              address: str,
              name: str = "",
              list_registered_methods=False,
              timeout_in_ms=0):
     """Creates an RPC client for the server at `address`.

     Args:
       address: Server address, forwarded to `gen_rpc_ops.rpc_client` as
         `server_address`.
       name: Forwarded to `gen_rpc_ops.rpc_client` as `shared_name`.
       list_registered_methods: Forwarded to `gen_rpc_ops.rpc_client`. Every
         serialized `RegisteredMethod` returned by the op is attached to this
         client via `_add_method` (presumably none are returned when this is
         False -- confirm against the op).
       timeout_in_ms: Forwarded to `gen_rpc_ops.rpc_client`.

     Raises:
       NotImplementedError: If not executing eagerly.
     """
     # Validate the execution mode before creating the client op, so graph
     # mode fails fast instead of first adding an rpc_client op that is never
     # used.
     if not context.executing_eagerly():
         raise NotImplementedError(
             "Client creation is supported only in eager mode.")
     self._client_handle, methods = gen_rpc_ops.rpc_client(
         shared_name=name,
         server_address=address,
         list_registered_methods=list_registered_methods,
         timeout_in_ms=timeout_in_ms)
     # Presumably releases the client resource when this object is collected.
     self._handle_deleter = resource_variable_ops.EagerResourceDeleter(
         handle=self._client_handle,
         handle_device=self._client_handle.device)
     self._server_address = address
     # Method name -> decoded output specs for that method.
     self._method_registry = {}
     for method in methods.numpy():
         registered = rpc_pb2.RegisteredMethod()
         registered.ParseFromString(method)
         output_specs = nested_structure_coder.decode_proto(
             registered.output_specs)
         input_specs = nested_structure_coder.decode_proto(
             registered.input_specs)
         self._method_registry[registered.method] = output_specs
         # TODO(ishark): Perhaps doc string can also be taken as input during
         # function registration.
         doc_string = ("RPC Call for " + registered.method +
                       " method to server " + address)
         self._add_method(registered.method, output_specs, input_specs,
                          self._client_handle, doc_string)
 def testDecodeUnknownTensorSpec(self):
     """decode_proto rejects a TypeSpec whose class name is unknown."""
     proto = struct_pb2.StructuredValue()
     spec_value = proto.type_spec_value
     spec_value.type_spec_class = 0
     spec_value.type_spec_class_name = "FutureTensorSpec"
     expected_message = "The type 'FutureTensorSpec' is not supported"
     with self.assertRaisesRegex(ValueError, expected_message):
         nested_structure_coder.decode_proto(proto)
 def testEncodeDecodeBoundedTensorSpecNoName(self):
     """An unnamed BoundedTensorSpec round-trips, encoding name as ""."""
     spec = tensor_spec.BoundedTensorSpec((28, 28, 3), dtypes.float64, -2,
                                          (1, 1, 20))
     structure = [spec]
     self.assertTrue(nested_structure_coder.can_encode(structure))
     encoded = nested_structure_coder.encode_structure(structure)

     expected = struct_pb2.StructuredValue()
     bounded = expected.list_value.values.add().bounded_tensor_spec_value
     for size in (28, 28, 3):
         bounded.shape.dim.add().size = size
     bounded.name = ""
     bounded.dtype = dtypes.float64.as_datatype_enum
     # The scalar minimum and the per-element maximum are stored as tensors.
     minimum_proto = tensor_util.make_tensor_proto(
         [-2], dtype=dtypes.float64, shape=[])
     maximum_proto = tensor_util.make_tensor_proto(
         [1, 1, 20], dtype=dtypes.float64, shape=[3])
     bounded.minimum.CopyFrom(minimum_proto)
     bounded.maximum.CopyFrom(maximum_proto)
     self.assertEqual(expected, encoded)

     decoded = nested_structure_coder.decode_proto(encoded)
     self.assertEqual(structure, decoded)
Esempio n. 4
0
def _get_element_from_tensor_info(tensor_info, graph):
    """Resolve a TensorInfo proto to the graph element it describes.

    A pared-down version of the deprecated `get_tensor_from_tensor_info`.

    Args:
      tensor_info: A TensorInfo proto.
      graph: The graph in which names are looked up.

    Returns:
      For the "name" encoding, whatever `graph.as_graph_element` returns;
      for "coo_sparse", a `SparseTensor`; for "composite_tensor", the
      composite rebuilt from its decoded TypeSpec and component tensors.

    Raises:
      ValueError: If the encoding is none of the three supported ones.
    """
    encoding = tensor_info.WhichOneof("encoding")

    if encoding == "name":
        # TensorInfo is a bit of a misnomer here: the name may refer to an
        # operation rather than a tensor.
        return graph.as_graph_element(tensor_info.name)

    if encoding == "coo_sparse":
        coo = tensor_info.coo_sparse
        indices = graph.get_tensor_by_name(coo.indices_tensor_name)
        values = graph.get_tensor_by_name(coo.values_tensor_name)
        dense_shape = graph.get_tensor_by_name(coo.dense_shape_tensor_name)
        return sparse_tensor.SparseTensor(indices, values, dense_shape)

    if encoding == "composite_tensor":
        composite = tensor_info.composite_tensor
        spec = nested_structure_coder.decode_proto(
            struct_pb2.StructuredValue(type_spec_value=composite.type_spec))
        components = [graph.get_tensor_by_name(part.name)
                      for part in composite.components]
        return spec._from_components(components)  # pylint: disable=protected-access

    raise ValueError(f"Invalid TensorInfo.encoding: {encoding}. Valid "
                     "encodings are 'name', 'coo_sparse', and "
                     "'composite_tensor'.")
 def testEncodeDecodeSparseTensorSpec(self):
     """SparseTensorSpec encodes as a SPARSE_TENSOR_SPEC type_spec_value."""
     sparse_spec = sparse_tensor.SparseTensorSpec([10, 20], dtypes.float32)
     structure = [sparse_spec]
     self.assertTrue(nested_structure_coder.can_encode(structure))
     encoded = nested_structure_coder.encode_structure(structure)
     expected_pbtxt = r"""
   list_value {
     values {
       type_spec_value {
         type_spec_class: SPARSE_TENSOR_SPEC
         type_spec_class_name: 'SparseTensorSpec'
         num_flat_components: 3
         type_state {
           tuple_value {
             # spec._shape
             values {
               tensor_shape_value {
                 dim { size: 10 }
                 dim { size: 20 }
               }
             }
             # spec._dtype
             values { tensor_dtype_value: DT_FLOAT }
           }
         }
       }
     }
   }
 """
     expected = text_format.Parse(expected_pbtxt, struct_pb2.StructuredValue())
     self.assertEqual(expected, encoded)
     self.assertEqual(structure, nested_structure_coder.decode_proto(encoded))
Esempio n. 6
0
 def lookup_tensor_or_sparse_or_composite_tensor(tensor_info):
   """Returns the remapped tensor corresponding to TensorInfo.

   Every tensor name found in `tensor_info` is resolved through
   `lookup_remapped_tensor`.

   Args:
     tensor_info: A TensorInfo proto with a 'name', 'coo_sparse' or
       'composite_tensor' encoding.

   Returns:
     The remapped tensor for 'name', a `tf.SparseTensor` for 'coo_sparse',
     or the composite rebuilt from its decoded TypeSpec for
     'composite_tensor'.

   Raises:
     ValueError: If the encoding is not one of the above.
   """
   encoding = tensor_info.WhichOneof('encoding')
   if encoding == 'coo_sparse':
     return tf.SparseTensor(
         lookup_remapped_tensor(tensor_info.coo_sparse.indices_tensor_name),
         lookup_remapped_tensor(tensor_info.coo_sparse.values_tensor_name),
         lookup_remapped_tensor(
             tensor_info.coo_sparse.dense_shape_tensor_name))
   elif encoding == 'composite_tensor':
     components = [lookup_remapped_tensor(info.name)
                   for info in tensor_info.composite_tensor.components]
     spec_proto = struct_pb2.StructuredValue(
         type_spec_value=tensor_info.composite_tensor.type_spec)
     # StructureCoder.decode_proto was migrated after TF 2.7 to the
     # module-level nested_structure_coder.decode_proto. Feature-detect the
     # new API rather than catching AttributeError around the call: an
     # AttributeError raised *inside* decode_proto must not be masked by a
     # fallback to StructureCoder, which may not exist in newer TF.
     if hasattr(nested_structure_coder, 'decode_proto'):
       spec = nested_structure_coder.decode_proto(spec_proto)
     else:
       struct_coder = nested_structure_coder.StructureCoder()
       spec = struct_coder.decode_proto(spec_proto)
     return spec._from_components(components)  # pylint: disable=protected-access
   elif encoding == 'name':
     return lookup_remapped_tensor(tensor_info.name)
   else:
     raise ValueError('Unsupported TensorInfo encoding %s' % encoding)
 def testBool(self):
     """A bool inside a list round-trips through bool_value."""
     structure = [False]
     self.assertTrue(nested_structure_coder.can_encode(structure))
     encoded = nested_structure_coder.encode_structure(structure)

     expected = struct_pb2.StructuredValue()
     element = expected.list_value.values.add()
     element.bool_value = False
     self.assertEqual(expected, encoded)

     self.assertEqual(structure, nested_structure_coder.decode_proto(encoded))
 def testDtype(self):
     """A DType inside a list round-trips through tensor_dtype_value."""
     structure = [dtypes.int64]
     self.assertTrue(nested_structure_coder.can_encode(structure))
     encoded = nested_structure_coder.encode_structure(structure)

     expected = struct_pb2.StructuredValue()
     dtype_enum = dtypes.int64.as_datatype_enum
     expected.list_value.values.add().tensor_dtype_value = dtype_enum
     self.assertEqual(expected, encoded)

     self.assertEqual(structure, nested_structure_coder.decode_proto(encoded))
 def testNone(self):
     """None inside a list is encoded as none_value and decodes back."""
     structure = [1.0, None]
     self.assertTrue(nested_structure_coder.can_encode(structure))
     encoded = nested_structure_coder.encode_structure(structure)

     expected = struct_pb2.StructuredValue()
     items = expected.list_value.values
     items.add().float64_value = 1.0
     items.add().none_value.CopyFrom(struct_pb2.NoneValue())
     self.assertEqual(expected, encoded)

     self.assertEqual(structure, nested_structure_coder.decode_proto(encoded))
 def testEncodeDecodeList(self):
     """A list of floats round-trips as float64_value elements."""
     structure = [1.5, 2.5, 3.0]
     self.assertTrue(nested_structure_coder.can_encode(structure))
     encoded = nested_structure_coder.encode_structure(structure)

     expected = struct_pb2.StructuredValue()
     for number in structure:
         expected.list_value.values.add().float64_value = number
     self.assertEqual(expected, encoded)

     self.assertEqual(structure, nested_structure_coder.decode_proto(encoded))
 def testEncodeDecodeDict(self):
     """A dict with int and nested-list values round-trips exactly."""
     structure = {"a": 3, "b": [7, 2.5]}
     self.assertTrue(nested_structure_coder.can_encode(structure))
     encoded = nested_structure_coder.encode_structure(structure)

     expected = struct_pb2.StructuredValue()
     fields = expected.dict_value.fields
     fields["a"].int64_value = 3
     b_items = fields["b"].list_value.values
     b_items.add().int64_value = 7
     b_items.add().float64_value = 2.5
     self.assertEqual(expected, encoded)

     decoded = nested_structure_coder.decode_proto(encoded)
     # The int64 value must come back as a Python int, not a float.
     self.assertIsInstance(decoded["a"], int)
     self.assertEqual(structure, decoded)
Esempio n. 12
0
 def _setup_functions_structures(self):
     """Attach decoded input/output structures to each restored function.

     For every saved concrete function, visited in name order, the decoded
     `output_signature` becomes its FuncGraph's `structured_outputs`, the
     decoded `canonicalized_input_signature` becomes its
     `structured_input_signature`, and its function spec is re-initialized.
     """
     saved_functions = self._proto.concrete_functions
     for fn_name, fn_proto in sorted(saved_functions.items()):
         restored = self._concrete_functions[fn_name]
         # Assigning structured_outputs directly lets this ConcreteFunction
         # reuse its normal output-repacking logic. That logic only handles
         # outputs convertible to Tensors and always converts them, e.g.
         # tf.TensorShape([2, 3]) comes back as a Tensor holding [2, 3].
         outputs = nested_structure_coder.decode_proto(
             fn_proto.output_signature)
         # The decoded outputs hold TensorSpecs where the original function
         # returned Tensors, so structured_outputs will not match exactly.
         # Repacking depends only on structure, and unpacking only on
         # structure and types, so this is sufficient.
         restored._func_graph.structured_outputs = outputs  # pylint: disable=protected-access
         inputs = nested_structure_coder.decode_proto(
             fn_proto.canonicalized_input_signature)
         restored._func_graph.structured_input_signature = inputs  # pylint: disable=protected-access
         restored._initialize_function_spec()  # pylint: disable=protected-access
Esempio n. 13
0
def _deserialize_function_spec_as_nonmethod(function_spec_proto):
  """Rebuild a non-method FunctionSpec from its serialized proto.

  If the proto describes a method, the leading positional argument (`self`)
  is dropped so the resulting spec always has `is_method=False`.

  Args:
    function_spec_proto: A `saved_object_graph_pb2.FunctionSpec`.

  Returns:
    A `function_spec_lib.FunctionSpec`.

  Raises:
    NotImplementedError: If the proto is a method with no positional args.
  """
  typeless = nested_structure_coder.decode_proto(
      function_spec_proto.fullargspec)

  args = typeless.args
  if function_spec_proto.is_method:
    if not args:
      raise NotImplementedError(
          "Cannot deserialize a method function without a named "
          "'self' argument.")
    # Strip `self` to turn the method signature into a plain function one.
    args = args[1:]

  fullargspec = tf_inspect.FullArgSpec(
      args=args,
      varargs=typeless.varargs,
      varkw=typeless.varkw,
      defaults=typeless.defaults,
      kwonlyargs=typeless.kwonlyargs,
      kwonlydefaults=typeless.kwonlydefaults,
      annotations=typeless.annotations)
  input_signature = nested_structure_coder.decode_proto(
      function_spec_proto.input_signature)

  # Map the JitCompile enum onto tf.function's tri-state `jit_compile`;
  # values outside the table also map to None.
  jit_enum = saved_object_graph_pb2.FunctionSpec.JitCompile
  jit_compile_by_enum = {
      jit_enum.DEFAULT: None,
      jit_enum.ON: True,
      jit_enum.OFF: False,
  }
  jit_compile = jit_compile_by_enum.get(function_spec_proto.jit_compile)

  return function_spec_lib.FunctionSpec(
      fullargspec=fullargspec,
      is_method=False,
      input_signature=input_signature,
      jit_compile=jit_compile)
    def testRegisteredTypeSpec(self):
        """Encoding a registered TypeSpec emits exactly one import warning."""
        structure = {"x": RegisteredTypeSpec()}
        expected_warning = ("Encoding a StructuredValue with type "
                            "NestedStructureTest.RegisteredTypeSpec; loading "
                            "this StructuredValue will require that this type "
                            "be imported and registered")

        self.assertTrue(nested_structure_coder.can_encode(structure))
        with warnings.catch_warnings(record=True) as caught:
            encoded = nested_structure_coder.encode_structure(structure)
            self.assertLen(caught, 1)
            self.assertIn(expected_warning, str(caught[0].message))
        self.assertEqual(structure,
                         nested_structure_coder.decode_proto(encoded))
 def testEmptyStructures(self):
     """Empty list, dict and tuple each encode as an empty container value."""
     structure = [[], {}, ()]
     self.assertTrue(nested_structure_coder.can_encode(structure))
     encoded = nested_structure_coder.encode_structure(structure)

     expected = struct_pb2.StructuredValue()
     items = expected.list_value.values
     items.add().list_value.CopyFrom(struct_pb2.ListValue())
     items.add().dict_value.CopyFrom(struct_pb2.DictValue())
     items.add().tuple_value.CopyFrom(struct_pb2.TupleValue())
     self.assertEqual(expected, encoded)

     self.assertEqual(structure, nested_structure_coder.decode_proto(encoded))
 def testEncodeDecodeTuple(self):
     """Nested tuples and lists keep their container kinds on round-trip."""
     structure = ("hello", [3, (2, 1)])
     self.assertTrue(nested_structure_coder.can_encode(structure))
     encoded = nested_structure_coder.encode_structure(structure)

     expected = struct_pb2.StructuredValue()
     outer = expected.tuple_value.values
     outer.add().string_value = "hello"
     inner_list = outer.add().list_value.values
     inner_list.add().int64_value = 3
     inner_tuple = inner_list.add().tuple_value.values
     for number in (2, 1):
         inner_tuple.add().int64_value = number
     self.assertEqual(expected, encoded)

     self.assertEqual(structure, nested_structure_coder.decode_proto(encoded))
 def testEncodeDataSetSpec(self):
     """A DatasetSpec over ragged, sparse and dense specs round-trips."""
     element_spec = {
         "rt": ragged_tensor.RaggedTensorSpec([10, None], dtypes.int32),
         "st": sparse_tensor.SparseTensorSpec([10, 20], dtypes.float32),
         "t": tensor_spec.TensorSpec([10, 8], dtypes.string),
     }
     structure = [dataset_ops.DatasetSpec(element_spec)]
     self.assertTrue(nested_structure_coder.can_encode(structure))
     encoded = nested_structure_coder.encode_structure(structure)
     self.assertEqual(structure, nested_structure_coder.decode_proto(encoded))
 def testEncodeDecodeTensorSpecWithNoName(self):
     """An unnamed TensorSpec round-trips, encoding name as ""."""
     structure = [tensor_spec.TensorSpec([1, 2, 3], dtypes.int64)]
     self.assertTrue(nested_structure_coder.can_encode(structure))
     encoded = nested_structure_coder.encode_structure(structure)

     expected = struct_pb2.StructuredValue()
     spec_proto = expected.list_value.values.add().tensor_spec_value
     for size in (1, 2, 3):
         spec_proto.shape.dim.add().size = size
     spec_proto.name = ""
     spec_proto.dtype = dtypes.int64.as_datatype_enum
     self.assertEqual(expected, encoded)

     self.assertEqual(structure, nested_structure_coder.decode_proto(encoded))
 def testEncodeDecodeTensorShape(self):
     """A TensorShape followed by a string inside a list round-trips."""
     structure = [tensor_shape.TensorShape([1, 2, 3]), "hello"]
     self.assertTrue(nested_structure_coder.can_encode(structure))
     encoded = nested_structure_coder.encode_structure(structure)
     expected = struct_pb2.StructuredValue()
     expected_list = expected.list_value
     # First element: the TensorShape.
     expected_tensor_shape = expected_list.values.add().tensor_shape_value
     expected_tensor_shape.dim.add().size = 1
     expected_tensor_shape.dim.add().size = 2
     expected_tensor_shape.dim.add().size = 3
     # Second element: the plain string.
     expected_list.values.add().string_value = "hello"
     self.assertEqual(expected, encoded)
     decoded = nested_structure_coder.decode_proto(encoded)
     self.assertEqual(structure, decoded)
Esempio n. 20
0
def composite_tensor_info_to_type_spec(tensor_info):
  """Converts a `TensorInfo` for a composite tensor to a `TypeSpec` object.

  Args:
    tensor_info: A TensorInfo proto whose "encoding" oneof is
      "composite_tensor".

  Returns:
    The TypeSpec decoded from `tensor_info.composite_tensor.type_spec`.

  Raises:
    ValueError: If `nested_structure_coder` or `struct_pb2` is unavailable
      (None) in this TensorFlow build, or if `tensor_info` does not use the
      composite_tensor encoding.
  """
  if nested_structure_coder is None or struct_pb2 is None:
    raise ValueError("This version of TensorFlow does not support "
                     "composite tensors.")
  if tensor_info.WhichOneof("encoding") != "composite_tensor":
    raise ValueError("Expected a TensorInfo with encoding=composite_tensor")
  spec_proto = struct_pb2.StructuredValue(
      type_spec_value=tensor_info.composite_tensor.type_spec)
  # StructureCoder.decode_proto was migrated after TF 2.7 to the module-level
  # nested_structure_coder.decode_proto. Feature-detect the new API rather
  # than catching AttributeError around the call: an AttributeError raised
  # *inside* decode_proto must not be masked by a fallback to StructureCoder,
  # which may not exist in newer TF.
  if hasattr(nested_structure_coder, "decode_proto"):
    return nested_structure_coder.decode_proto(spec_proto)
  struct_coder = nested_structure_coder.StructureCoder()
  return struct_coder.decode_proto(spec_proto)
Esempio n. 21
0
 def testBuildTensorInfoRagged(self):
   """build_tensor_info on a RaggedTensor records components and type spec."""
   x = ragged_factory_ops.constant([[1, 2], [3]])
   x_tensor_info = utils.build_tensor_info(x)
   composite = x_tensor_info.composite_tensor
   # Components: the flat values first, then the row splits.
   expected_components = [
       (x.values.name, types_pb2.DT_INT32),
       (x.row_splits.name, types_pb2.DT_INT64),
   ]
   for index, (tensor_name, dtype) in enumerate(expected_components):
     self.assertEqual(tensor_name, composite.components[index].name)
     self.assertEqual(dtype, composite.components[index].dtype)
   # The serialized type spec must decode back to the tensor's own spec.
   spec = nested_structure_coder.decode_proto(
       struct_pb2.StructuredValue(type_spec_value=composite.type_spec))
   self.assertEqual(spec, x._type_spec)
    def testEncodeDecodeExtensionTypeSpec(self):
        """An ExtensionType Spec encodes as EXTENSION_TYPE_SPEC and round-trips."""

        class Zoo(extension_type.ExtensionType):
            __name__ = "tf.nested_structure_coder_test.Zoo"
            zookeepers: typing.Tuple[str, ...]
            animals: typing.Mapping[str, ops.Tensor]

        zoo_spec = Zoo.Spec(zookeepers=["Zoey", "Zack"],
                            animals={"tiger": tensor_spec.TensorSpec([16])})
        structure = [zoo_spec]

        self.assertTrue(nested_structure_coder.can_encode(structure))
        encoded = nested_structure_coder.encode_structure(structure)
        expected_pbtxt = r"""
    list_value {
      values {
        type_spec_value {
          type_spec_class: EXTENSION_TYPE_SPEC
          type_spec_class_name: "tf.nested_structure_coder_test.Zoo.Spec"
          num_flat_components: 1
          type_state {
            tuple_value {
              values {
                tuple_value {
                  values { string_value: "zookeepers" }
                  values { tuple_value {
                    values { string_value: "Zoey" }
                    values { string_value: "Zack" } } } } }
              values {
                tuple_value {
                  values { string_value: "animals" }
                  values { dict_value {
                    fields {
                      key: "tiger"
                      value { tensor_spec_value {
                        shape { dim { size: 16 } }
                        dtype: DT_FLOAT } } } } } } } } } } } }
    """
        expected = text_format.Parse(expected_pbtxt,
                                     struct_pb2.StructuredValue())
        self.assertEqual(expected, encoded)
        self.assertEqual(structure,
                         nested_structure_coder.decode_proto(encoded))
Esempio n. 23
0
def get_tensor_from_tensor_info(tensor_info, graph=None, import_scope=None):
  """Resolve a TensorInfo proto to the Tensor or CompositeTensor it names.

  Args:
    tensor_info: A TensorInfo proto describing a Tensor, SparseTensor or
      CompositeTensor.
    graph: The tf.Graph used for lookups; a falsy value (e.g. None) means
      the current default graph.
    import_scope: If not None, every name in `tensor_info` is prefixed with
      this scope before lookup.

  Returns:
    The Tensor, SparseTensor or CompositeTensor in `graph` that
    `tensor_info` describes.

  Raises:
    KeyError: If `tensor_info` does not correspond to a tensor in `graph`.
    ValueError: If `tensor_info` is malformed.
  """
  graph = graph or ops.get_default_graph()

  def _lookup(tensor_name):
    scoped_name = ops.prepend_name_scope(tensor_name,
                                         import_scope=import_scope)
    return graph.get_tensor_by_name(scoped_name)

  encoding = tensor_info.WhichOneof("encoding")

  if encoding == "name":
    return _lookup(tensor_info.name)

  if encoding == "coo_sparse":
    coo = tensor_info.coo_sparse
    return sparse_tensor.SparseTensor(
        _lookup(coo.indices_tensor_name),
        _lookup(coo.values_tensor_name),
        _lookup(coo.dense_shape_tensor_name))

  if encoding == "composite_tensor":
    composite = tensor_info.composite_tensor
    spec = nested_structure_coder.decode_proto(
        struct_pb2.StructuredValue(type_spec_value=composite.type_spec))
    parts = [_lookup(part.name) for part in composite.components]
    return nest.pack_sequence_as(spec, parts, expand_composites=True)

  raise ValueError(f"Invalid TensorInfo.encoding: {encoding}. Expected `"
                   "coo_sparse`, `composite_tensor`, or `name` for a dense "
                   "tensor.")
Esempio n. 24
0
  def _infer_inputs(self, layer_node_id, convert_to_shapes=False):
    """Infer a layer's input specs (or shapes) from its saved call function.

    Args:
      layer_node_id: Node id of the layer in the SavedModel object graph.
      convert_to_shapes: If True, map each input spec to its `.shape`.

    Returns:
      The decoded input structure (or its shapes), or None when the layer
      has no `call_and_return_all_conditional_losses` child or that child
      has no concrete functions.
    """
    call_fn_id = self._search_for_child_node(
        layer_node_id, ['call_and_return_all_conditional_losses'])
    if call_fn_id is None:
      return None

    fn_names = self._proto.nodes[call_fn_id].function.concrete_functions
    if not fn_names:
      return None
    # Only the first concrete function is inspected.
    fn_proto = self._proto.concrete_functions[fn_names[0]]
    input_signature = nested_structure_coder.decode_proto(
        fn_proto.canonicalized_input_signature)
    # Presumably (args, kwargs); the layer inputs are the first positional
    # argument.
    inputs = input_signature[0][0]
    if not convert_to_shapes:
      return inputs
    return nest.map_structure(lambda spec: spec.shape, inputs)
 def testEncodeDecodeRaggedTensorSpec(self):
     """RaggedTensorSpec encodes as a RAGGED_TENSOR_SPEC type_spec_value."""
     ragged_spec = ragged_tensor.RaggedTensorSpec([1, 2, 3], dtypes.int64, 2,
                                                  dtypes.int32)
     structure = [ragged_spec]
     self.assertTrue(nested_structure_coder.can_encode(structure))
     encoded = nested_structure_coder.encode_structure(structure)
     expected_pbtxt = r"""
   list_value {
     values {
       type_spec_value {
         type_spec_class: RAGGED_TENSOR_SPEC
         type_spec_class_name: 'RaggedTensorSpec'
         num_flat_components: 3
         type_state {
           tuple_value {
             # spec._shape
             values {
               tensor_shape_value {
                 dim { size: 1 }
                 dim { size: 2 }
                 dim { size: 3 }
               }
             }
             # spec._dtype
             values { tensor_dtype_value: DT_INT64 }
             # spec._ragged_rank
             values { int64_value: 2 }
             # spec._row_splits_dtype
             values { tensor_dtype_value: DT_INT32 }
           }
         }
       }
     }
   }
 """
     expected = text_format.Parse(expected_pbtxt, struct_pb2.StructuredValue())
     self.assertEqual(expected, encoded)
     self.assertEqual(structure, nested_structure_coder.decode_proto(encoded))
Esempio n. 26
0
 def from_serialized_proto(cls, proto_string: bytes) -> 'TableInfo':
     """Build a TableInfo from bytes of a serialized `schema_pb2.TableInfo`.

     The optional `signature` field is decoded with
     `nested_structure_coder.decode_proto`; when it is unset, the resulting
     TableInfo gets `signature=None`.
     """
     proto = schema_pb2.TableInfo.FromString(proto_string)
     signature = (nested_structure_coder.decode_proto(proto.signature)
                  if proto.HasField('signature') else None)
     return cls(
         name=proto.name,
         sampler_options=proto.sampler_options,
         remover_options=proto.remover_options,
         max_size=proto.max_size,
         max_times_sampled=proto.max_times_sampled,
         rate_limiter_info=proto.rate_limiter_info,
         signature=signature,
         current_size=proto.current_size,
         num_episodes=proto.num_episodes,
         num_deleted_episodes=proto.num_deleted_episodes,
         num_unique_samples=proto.num_unique_samples,
         table_worker_time=proto.table_worker_time,
     )
 def testEncodeDecodeNamedTuple(self):
     """A namedtuple keeps its field values and type name on round-trip."""
     point_type = collections.namedtuple("NamedTuple", ["x", "y"])
     instance = point_type(x=[1, 2], y="hello")
     self.assertTrue(nested_structure_coder.can_encode(instance))
     encoded = nested_structure_coder.encode_structure(instance)

     expected = struct_pb2.StructuredValue()
     named = expected.named_tuple_value
     named.name = "NamedTuple"
     x_pair = named.values.add()
     x_pair.key = "x"
     x_items = x_pair.value.list_value.values
     x_items.add().int64_value = 1
     x_items.add().int64_value = 2
     y_pair = named.values.add()
     y_pair.key = "y"
     y_pair.value.string_value = "hello"
     self.assertEqual(expected, encoded)

     decoded = nested_structure_coder.decode_proto(encoded)
     # The decoded class is rebuilt, so compare contents and name, not type.
     self.assertEqual(instance._asdict(), decoded._asdict())
     self.assertEqual(type(instance).__name__, type(decoded).__name__)
 def testEncodeDecodeBoundedTensorSpec(self):
     """A named BoundedTensorSpec with scalar bounds round-trips."""
     spec = tensor_spec.BoundedTensorSpec([1, 2, 3], dtypes.int64, 0, 10,
                                          "hello-0-10")
     structure = [spec]
     self.assertTrue(nested_structure_coder.can_encode(structure))
     encoded = nested_structure_coder.encode_structure(structure)

     expected = struct_pb2.StructuredValue()
     bounded = expected.list_value.values.add().bounded_tensor_spec_value
     for size in (1, 2, 3):
         bounded.shape.dim.add().size = size
     bounded.name = "hello-0-10"
     bounded.dtype = dtypes.int64.as_datatype_enum
     # Scalar bounds are stored as rank-0 int64 tensor protos.
     bounded.minimum.CopyFrom(
         tensor_util.make_tensor_proto([0], dtype=dtypes.int64, shape=[]))
     bounded.maximum.CopyFrom(
         tensor_util.make_tensor_proto([10], dtype=dtypes.int64, shape=[]))
     self.assertEqual(expected, encoded)

     self.assertEqual(structure, nested_structure_coder.decode_proto(encoded))
def load_function_def_library(library,
                              saved_object_graph=None,
                              load_shared_name_suffix=None,
                              wrapper_function=None):
    """Load a set of functions as concrete functions without captured inputs.

    Function names are manipulated during load such that they do not overlap
    with previously created ones.

    Gradients are re-registered under new names. Ops that reference the
    gradients are updated to reflect the new registered names.

    Args:
      library: FunctionDefLibrary proto message.
      saved_object_graph: SavedObjectGraph proto message. If not passed in,
        concrete function structured signatures and outputs will not be set.
      load_shared_name_suffix: If specified, used to uniquify shared
        names. Otherwise, a unique "_load_<uid>" suffix is generated.
      wrapper_function: Optional callable applied to each newly created
        ConcreteFunction; its return value is what gets stored, added to the
        graph and returned.

    Returns:
      Map of original function names in the library to instances of
      `ConcreteFunction` without captured inputs.

    Raises:
      ValueError: if functions dependencies have a cycle (presumably raised
        by `_sort_function_defs`).
    """
    # Names of every function defined in `library`; handed to
    # `_list_function_deps`, presumably so only in-library deps are tracked.
    library_function_names = set(fdef.signature.name
                                 for fdef in library.function)
    # Original FunctionDef name -> loaded ConcreteFunction (the return value).
    functions = {}
    # Loaded function's new name (`func.name`) -> ConcreteFunction; consulted
    # by `_restore_gradient_functions`.
    renamed_functions = {}

    # Our graph building code currently requires functions to be registered
    # with some tf.Graph in order to import functions using the
    # op-name-is-function-name calling convention. To avoid leaking memory
    # into the global default graph when executing eagerly, we create a
    # temporary Graph.
    #
    # TODO(b/205023033): Make this Graph creation unnecessary when executing
    # eagerly by fixing function_def_to_graph_def.
    if ops.executing_eagerly_outside_functions():
        graph = ops.Graph()
    else:
        graph = ops.get_default_graph()

    if load_shared_name_suffix is None:
        load_shared_name_suffix = "_load_{}".format(ops.uid())

    # Custom gradient functions must be re-registered under new UIDs.
    library_gradient_names = {}  # Maps old op type to old function name
    new_gradient_op_types = {}  # Maps old gradient op type to new op type.
    gradients_to_register = {}  # Maps old function name to new op type
    for gdef in library.registered_gradients:
        # Entries without a registered op type are skipped.
        if gdef.registered_op_type:
            new_op_type = custom_gradient.generate_name()
            old_op_type = compat.as_bytes(gdef.registered_op_type)

            library_gradient_names[old_op_type] = gdef.gradient_func
            new_gradient_op_types[old_op_type] = new_op_type
            gradients_to_register[gdef.gradient_func] = new_op_type

    # FunctionDef name -> names of the functions it depends on.
    function_deps = {}
    for fdef in library.function:
        function_deps[fdef.signature.name] = _list_function_deps(
            fdef, library_function_names, library_gradient_names)

    # New gradient op type (bytes) -> loaded gradient ConcreteFunction.
    loaded_gradients = {}
    # `_sort_function_defs` presumably yields definitions in dependency order;
    # the `functions[dep]` lookup below requires every dependency to have
    # been loaded already.
    for fdef in _sort_function_defs(library, function_deps):
        copy = _fix_fdef(fdef, functions, load_shared_name_suffix,
                         new_gradient_op_types)

        # Setup function signatures and outputs
        #
        # When concrete functions are created normally (i.e. when they're
        # originally created and not loaded via saved model), the inputs and
        # outputs are calculated based on the values passed in by the user and
        # returned from the original function, respectively. We don't have
        # access to those anymore at restore time, so we must instead pass
        # them to the FuncGraph explicitly.
        structured_input_signature = None
        structured_outputs = None
        if (saved_object_graph is not None and fdef.signature.name
                in saved_object_graph.concrete_functions):
            # TODO(b/204324043): Offload the deserialization of the protos to
            # the first class objects by passing the actual protos. This is
            # blocked on importing `nested_structure_coder` in function.py
            # causing a circular dependency.
            proto = saved_object_graph.concrete_functions[fdef.signature.name]
            structured_input_signature = nested_structure_coder.decode_proto(
                proto.canonicalized_input_signature)
            structured_outputs = nested_structure_coder.decode_proto(
                proto.output_signature)

        # There is no need to copy all functions into the function def graph.
        # It leads to a O(n^2) increase of memory when importing functions and
        # the extra function definitions are a no-op since they were already
        # imported as a function before and passed in explicitly (due to the
        # topological sort import).
        with graph.as_default():
            func_graph = function_def_lib.function_def_to_graph(
                copy,
                structured_input_signature=structured_input_signature,
                structured_outputs=structured_outputs)
        # Restores gradients for function-call ops (not the same as ops that use
        # custom gradients)
        _restore_gradient_functions(func_graph, renamed_functions,
                                    loaded_gradients)

        # Make every already-loaded dependency available inside this graph.
        for dep in function_deps[fdef.signature.name]:
            functions[dep].add_to_graph(func_graph)

        # We do not initialize the new ConcreteFunction's function_spec and/or
        # arg_keywords here (which are used to parse the structured and flat
        # signatures, respectively). ConcreteFunction that are part of a saved
        # function is set up later by recreate_function(); and bare
        # ConcreteFunction is set up by setup_bare_concrete_function().
        # However, we copy the FunctionDef attributes to the new
        # ConcreteFunction, excluding the "_input_shapes", which may cause an
        # error during input shape initialization at a later stage.
        if "_input_shapes" in copy.attr:
            del copy.attr["_input_shapes"]
        func = function_lib.ConcreteFunction(func_graph, attrs=copy.attr)
        if wrapper_function:
            func = wrapper_function(func)
        func.add_to_graph(graph)

        functions[fdef.signature.name] = func
        renamed_functions[func.name] = func
        if any(op.type == "TRTEngineOp" for op in func_graph.get_operations()):
            # TODO(b/150708051): Remove this hack once TensorRT SavedModel
            # integration is fixed. Currently it's leaking memory to maintain
            # bug compatibility with previous behavior.
            func.add_to_graph(ops.get_default_graph())

        # If this function is a custom gradient, register it under the newly
        # generated op type computed above.
        if fdef.signature.name in gradients_to_register:
            gradient_op_type = gradients_to_register[fdef.signature.name]
            loaded_gradients[compat.as_bytes(gradient_op_type)] = func
            ops.RegisterGradient(gradient_op_type)(_gen_gradient_func(func))

    return functions
Esempio n. 30
0
def unpack_pattern(config: Config) -> Pattern:
    """Return the pattern stored in `config`.

    When `config` has a `pattern_structure` field, that structure is decoded
    and used to unflatten `config.flat` with `tree.unflatten_as`; otherwise
    `config.flat` is returned unchanged.
    """
    if config.HasField('pattern_structure'):
        structure = nested_structure_coder.decode_proto(
            config.pattern_structure)
        return tree.unflatten_as(structure, config.flat)
    return config.flat