def _deserialize_federated_value(
    value_proto: executor_pb2.Value) -> _DeserializeReturnType:
  """Deserializes a value of federated type."""
  type_spec = type_serialization.deserialize_type(
      computation_pb2.Type(federated=value_proto.federated.type))
  value = []
  for item in value_proto.federated.value:
    item_value, item_type = deserialize_value(item)
    type_spec.member.check_assignable_from(item_type)
    value.append(item_value)
  if type_spec.all_equal:
    if len(value) == 1:
      value = value[0]
    else:
      raise ValueError(
          'Encountered an all_equal value with {} member constituents. '
          'Expected exactly 1.'.format(len(value)))
  return value, type_spec
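# A hedged round-trip sketch for `_deserialize_federated_value`, not part of
# the original module. It assumes the module's own `serialize_value`
# counterpart and a `placements` import (e.g. `placements.SERVER`); treat
# both names as assumptions.
def _example_deserialize_federated_value():
  type_spec = computation_types.FederatedType(
      tf.int32, placements.SERVER, all_equal=True)
  value_proto, _ = serialize_value(10, type_spec)
  value, deserialized_type = _deserialize_federated_value(value_proto)
  # An all_equal federated value with exactly one member constituent
  # deserializes to the bare member value.
  assert value == 10
  assert deserialized_type.all_equal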
def serialize_type(
    type_spec: Optional[computation_types.Type]) -> Optional[pb.Type]:
  """Serializes 'type_spec' as a pb.Type.

  Note: Currently only serialization for tensor, struct, sequence, function,
  placement, and federated types is implemented.

  Args:
    type_spec: A `computation_types.Type`, or `None`.

  Returns:
    The corresponding instance of `pb.Type`, or `None` if the argument was
    `None`.

  Raises:
    TypeError: if the argument is of the wrong type.
    NotImplementedError: for type variants for which serialization is not
      implemented.
  """
  if type_spec is None:
    return None
  cached_proto = _type_serialization_cache.get(type_spec, None)
  if cached_proto is not None:
    return cached_proto
  if type_spec.is_tensor():
    proto = pb.Type(tensor=_to_tensor_type_proto(type_spec))
  elif type_spec.is_sequence():
    proto = pb.Type(
        sequence=pb.SequenceType(element=serialize_type(type_spec.element)))
  elif type_spec.is_struct():
    proto = pb.Type(
        struct=pb.StructType(element=[
            pb.StructType.Element(name=e[0], value=serialize_type(e[1]))
            for e in structure.iter_elements(type_spec)
        ]))
  elif type_spec.is_function():
    proto = pb.Type(
        function=pb.FunctionType(
            parameter=serialize_type(type_spec.parameter),
            result=serialize_type(type_spec.result)))
  elif type_spec.is_placement():
    proto = pb.Type(placement=pb.PlacementType())
  elif type_spec.is_federated():
    proto = pb.Type(
        federated=pb.FederatedType(
            member=serialize_type(type_spec.member),
            placement=pb.PlacementSpec(
                value=pb.Placement(uri=type_spec.placement.uri)),
            all_equal=type_spec.all_equal))
  else:
    raise NotImplementedError
  _type_serialization_cache[type_spec] = proto
  return proto
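# A minimal usage sketch for `serialize_type`, not part of the original
# module; it assumes the same `computation_types`, `type_serialization`, and
# `pb` aliases used above, plus `import tensorflow as tf`.
def _example_serialize_type():
  int_proto = type_serialization.serialize_type(
      computation_types.TensorType(tf.int32))
  assert int_proto.tensor.dtype == tf.int32.as_datatype_enum
  pair_type = computation_types.StructType([('a', tf.int32), ('b', tf.bool)])
  pair_proto = type_serialization.serialize_type(pair_type)
  assert pair_proto.struct.element[0].name == 'a'
  # Protos are memoized in _type_serialization_cache, so serializing the same
  # type again returns the identical proto object.
  assert type_serialization.serialize_type(pair_type) is pair_proto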
def test_serialize_type_with_tensor_tuple(self):
  type_signature = computation_types.StructType([
      ('x', tf.int32),
      ('y', tf.string),
      tf.float32,
      ('z', tf.bool),
  ])
  actual_proto = type_serialization.serialize_type(type_signature)
  expected_proto = pb.Type(
      struct=pb.StructType(element=[
          pb.StructType.Element(
              name='x', value=_create_scalar_tensor_type(tf.int32)),
          pb.StructType.Element(
              name='y', value=_create_scalar_tensor_type(tf.string)),
          pb.StructType.Element(value=_create_scalar_tensor_type(tf.float32)),
          pb.StructType.Element(
              name='z', value=_create_scalar_tensor_type(tf.bool)),
      ]))
  self.assertEqual(actual_proto, expected_proto)
def serialize_tf2_as_tf_computation(target, parameter_type, unpack=None):
  """Serializes the 'target' as a TF computation with a given parameter type.

  Args:
    target: The entity to convert into and serialize as a TF computation. This
      can currently only be a Python function or `tf.function`, with arguments
      matching the 'parameter_type'.
    parameter_type: The parameter type specification if the target accepts a
      parameter, or `None` if the target doesn't declare any parameters. Either
      an instance of `types.Type`, or something that's convertible to it by
      `types.to_type()`.
    unpack: Whether to always unpack the parameter_type. Necessary for support
      of polymorphic tf2_computations.

  Returns:
    A tuple of (`pb.Computation`, `tff.Type`), where the computation contains
    the instance with the `pb.TensorFlow` variant set, and the type is an
    instance of `tff.Type`, potentially including Python container
    annotations, for use by TensorFlow computation wrappers.

  Raises:
    TypeError: If the arguments are of the wrong types.
    ValueError: If the signature of the target is not compatible with the given
      parameter type.
  """
  py_typecheck.check_callable(target)
  parameter_type = computation_types.to_type(parameter_type)
  argspec = function_utils.get_argspec(target)
  if argspec.args and parameter_type is None:
    raise ValueError(
        'Expected the target to declare no parameters, found {}.'.format(
            repr(argspec.args)))
  # In the codepath for TF V1 based serialization (tff.tf_computation), we get
  # the "wrapped" function to serialize. Here, target is the raw function to
  # be wrapped; however, we still need to know if the parameter_type should be
  # unpacked into multiple args and kwargs in order to construct the
  # TensorSpecs to be passed in the call to get_concrete_function below.
  unpack = function_utils.infer_unpack_needed(target, parameter_type, unpack)
  arg_typespecs, kwarg_typespecs, parameter_binding = (
      graph_utils.get_tf_typespec_and_binding(
          parameter_type, arg_names=argspec.args, unpack=unpack))

  # Pseudo-global to be appended to once when target_poly below is traced.
  type_and_binding_slot = []

  # N.B. To serialize a tf.function or eager python code, the return type must
  # be a flat list, tuple, or dict. However, the tff.tf_computation must be
  # able to handle structured inputs and outputs. Thus, we intercept the
  # result of calling the original target fn, introspect its structure to
  # create a result_type and bindings, and then return a flat dict output. It
  # is this new "unpacked" tf.function that we will serialize using
  # tf.saved_model.save.
  #
  # TODO(b/117428091): The return type limitation is primarily a limitation of
  # SignatureDefs and therefore of the signatures argument to
  # tf.saved_model.save. tf.functions attached to objects and loaded back with
  # tf.saved_model.load can take/return nests; this might offer a better
  # approach to the one taken here.
  @tf.function(autograph=False)
  def target_poly(*args, **kwargs):
    result = target(*args, **kwargs)
    result_dict, result_type, result_binding = (
        graph_utils.get_tf2_result_dict_and_binding(result))
    assert not type_and_binding_slot
    # A "side channel" python output.
    type_and_binding_slot.append((result_type, result_binding))
    return result_dict

  # Triggers tracing so that type_and_binding_slot is filled.
  cc_fn = target_poly.get_concrete_function(*arg_typespecs, **kwarg_typespecs)
  assert len(type_and_binding_slot) == 1
  result_type, result_binding = type_and_binding_slot[0]

  # N.B. Note that cc_fn does *not* accept the same args and kwargs as the
  # Python target_poly; instead, it must be called with **kwargs based on the
  # unique names embedded in the TensorSpecs inside arg_typespecs and
  # kwarg_typespecs. The (preliminary) parameter_binding tracks the mapping
  # between these tensor names and the components of the (possibly nested) TFF
  # input type. When cc_fn is serialized, concrete tensors for each input are
  # introduced, and the call finalize_binding(parameter_binding,
  # sigs['serving_default'].inputs) updates the bindings to reference these
  # concrete tensors.

  # Associate vars with unique names and explicitly attach to the Checkpoint:
  var_dict = {
      'var{:02d}'.format(i): v for i, v in enumerate(cc_fn.graph.variables)
  }
  saveable = tf.train.Checkpoint(fn=target_poly, **var_dict)

  try:
    # TODO(b/122081673): All we really need is the meta graph def, we could
    # probably just load that directly, e.g., using parse_saved_model from
    # tensorflow/python/saved_model/loader_impl.py, but I'm not sure we want
    # to depend on that presumably non-public symbol. Perhaps TF can expose a
    # way to just get the MetaGraphDef directly without saving to a tempfile?
    # This looks like a small change to v2.saved_model.save().
    outdir = tempfile.mkdtemp('savedmodel')
    tf.saved_model.save(saveable, outdir, signatures=cc_fn)

    graph = tf.Graph()
    with tf.compat.v1.Session(graph=graph) as sess:
      mgd = tf.saved_model.loader.load(
          sess, tags=[tf.saved_model.tag_constants.SERVING], export_dir=outdir)
  finally:
    shutil.rmtree(outdir)
  sigs = mgd.signature_def

  # TODO(b/123102455): Figure out how to support the init_op. The meta graph
  # def contains sigs['__saved_model_init_op'].outputs['__saved_model_init_op'].
  # It probably won't do what we want, because it will want to read from
  # Checkpoints, not just run Variable initializers (?). The right solution
  # may be to grab the target_poly.get_initialization_function(), and save a
  # sig for that.

  # Now, traverse the signature from the MetaGraphDef to find the actual
  # tensor names and write them into the bindings.
  finalize_binding(parameter_binding, sigs['serving_default'].inputs)
  finalize_binding(result_binding, sigs['serving_default'].outputs)

  annotated_type = computation_types.FunctionType(parameter_type, result_type)

  return pb.Computation(
      type=pb.Type(
          function=pb.FunctionType(
              parameter=type_serialization.serialize_type(parameter_type),
              result=type_serialization.serialize_type(result_type))),
      tensorflow=pb.TensorFlow(
          graph_def=serialization_utils.pack_graph_def(mgd.graph_def),
          parameter=parameter_binding,
          result=result_binding)), annotated_type
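# A hedged usage sketch for `serialize_tf2_as_tf_computation`, not part of
# the original module; the wrapped function and parameter type below are
# illustrative assumptions.
def _example_serialize_tf2():

  @tf.function
  def add_one(x):
    return x + 1

  comp_proto, comp_type = serialize_tf2_as_tf_computation(
      add_one, parameter_type=computation_types.TensorType(tf.int32))
  # comp_proto carries the packed GraphDef plus parameter/result bindings;
  # comp_type is the FunctionType(int32 -> int32) recovered from tracing.
  assert comp_proto.WhichOneof('computation') == 'tensorflow'
  assert comp_type.parameter.dtype == tf.int32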
def serialize_py_fn_as_tf_computation(target, parameter_type, context_stack):
  """Serializes the 'target' as a TF computation with a given parameter type.

  See also `serialize_tf2_as_tf_computation` for TensorFlow 2 serialization.

  Args:
    target: The entity to convert into and serialize as a TF computation. This
      can currently only be a Python function. In the future, we will add here
      support for serializing the various kinds of non-eager and eager
      functions, and eventually aim at full support for and compliance with TF
      2.0. This function is currently required to declare either zero
      parameters if `parameter_type` is `None`, or exactly one parameter if
      it's not `None`. The nested structure of this parameter must correspond
      to the structure of the 'parameter_type'. In the future, we may support
      targets with multiple args/keyword args (to be documented in the API and
      referenced from here).
    parameter_type: The parameter type specification if the target accepts a
      parameter, or `None` if the target doesn't declare any parameters. Either
      an instance of `types.Type`, or something that's convertible to it by
      `types.to_type()`.
    context_stack: The context stack to use.

  Returns:
    A tuple of (`pb.Computation`, `tff.Type`), where the computation contains
    the instance with the `pb.TensorFlow` variant set, and the type is an
    instance of `tff.Type`, potentially including Python container
    annotations, for use by TensorFlow computation wrappers.

  Raises:
    TypeError: If the arguments are of the wrong types.
    ValueError: If the signature of the target is not compatible with the given
      parameter type.
  """
  # TODO(b/113112108): Support a greater variety of target type signatures,
  # with keyword args or multiple args corresponding to elements of a tuple.
  # Document all accepted forms with examples in the API, and point to there
  # from here.
  py_typecheck.check_type(target, types.FunctionType)
  py_typecheck.check_type(context_stack, context_stack_base.ContextStack)
  parameter_type = computation_types.to_type(parameter_type)
  argspec = inspect.getargspec(target)  # pylint: disable=deprecated-method

  with tf.Graph().as_default() as graph:
    args = []
    if parameter_type is not None:
      if len(argspec.args) != 1:
        raise ValueError(
            'Expected the target to declare exactly one parameter, '
            'found {}.'.format(repr(argspec.args)))
      parameter_name = argspec.args[0]
      parameter_value, parameter_binding = graph_utils.stamp_parameter_in_graph(
          parameter_name, parameter_type, graph)
      args.append(parameter_value)
    else:
      if argspec.args:
        raise ValueError(
            'Expected the target to declare no parameters, found {}.'.format(
                repr(argspec.args)))
      parameter_binding = None
    context = tf_computation_context.TensorFlowComputationContext(graph)
    with context_stack.install(context):
      result = target(*args)

    # TODO(b/122081673): This needs to change for TF 2.0. We may also want to
    # allow the person creating a tff.tf_computation to specify a different
    # initializer; e.g., if it is known that certain variables will be
    # assigned immediately to arguments of the function, then it is wasteful
    # to initialize them before this.
    #
    # The following is a bit of a work around: the collections below may
    # contain variables more than once, hence we throw into a set. TFF needs
    # to ensure all variables are initialized, but not all variables are
    # always in the collections we expect. tff.learning._KerasModel tries to
    # pull Keras variables (that may or may not be in GLOBAL_VARIABLES) into
    # TFF_MODEL_VARIABLES for now.
    all_variables = set(
        tf.compat.v1.global_variables() + tf.compat.v1.local_variables() +
        tf.compat.v1.get_collection(
            graph_keys.GraphKeys.VARS_FOR_TFF_TO_INITIALIZE))
    if all_variables:
      # Use a readable but not-too-long name for the init_op.
      name = 'init_op_for_' + '_'.join(
          [v.name.replace(':0', '') for v in all_variables])
      if len(name) > 50:
        name = 'init_op_for_{}_variables'.format(len(all_variables))
      with tf.control_dependencies(context.init_ops):
        # Before running the main new init op, run any initializers for sub-
        # computations from context.init_ops. Variables from import_graph_def
        # will not make it into the global collections, and so will not be
        # initialized without this code path.
        init_op_name = tf.compat.v1.initializers.variables(
            all_variables, name=name).name
    elif context.init_ops:
      init_op_name = tf.group(
          *context.init_ops, name='subcomputation_init_ops').name
    else:
      init_op_name = None

    result_type, result_binding = graph_utils.capture_result_from_graph(
        result, graph)

  annotated_type = computation_types.FunctionType(parameter_type, result_type)

  return pb.Computation(
      type=pb.Type(
          function=pb.FunctionType(
              parameter=type_serialization.serialize_type(parameter_type),
              result=type_serialization.serialize_type(result_type))),
      tensorflow=pb.TensorFlow(
          graph_def=serialization_utils.pack_graph_def(graph.as_graph_def()),
          parameter=parameter_binding,
          result=result_binding,
          initialize_op=init_op_name)), annotated_type
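# A hedged usage sketch for `serialize_py_fn_as_tf_computation`, not part of
# the original module. The `context_stack_impl.context_stack` import path is
# an assumption, inferred from the `context_stack_base` reference above.
def _example_serialize_py_fn():

  def make_ten():
    return tf.constant(10)

  comp_proto, comp_type = serialize_py_fn_as_tf_computation(
      make_ten, parameter_type=None,
      context_stack=context_stack_impl.context_stack)
  # With no parameter, the parameter binding is left unset; the result
  # binding names the tensor captured by capture_result_from_graph.
  assert not comp_proto.tensorflow.HasField('parameter')
  assert comp_type.parameter is None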
def test_serialize_type_with_string_sequence(self):
  actual_proto = type_serialization.serialize_type(
      computation_types.SequenceType(tf.string))
  expected_proto = pb.Type(
      sequence=pb.SequenceType(element=_create_scalar_tensor_type(tf.string)))
  self.assertEqual(actual_proto, expected_proto)
def test_serialize_tensor_type(self, dtype, shape):
  # Construct a TensorType explicitly; serialize_type expects a
  # computation_types.Type rather than a raw (dtype, shape) tuple.
  actual_proto = type_serialization.serialize_type(
      computation_types.TensorType(dtype, shape))
  expected_proto = pb.Type(
      tensor=pb.TensorType(
          dtype=dtype.as_datatype_enum, dims=_shape_to_dims(shape)))
  self.assertEqual(actual_proto, expected_proto)
def _create_scalar_tensor_type(dtype):
  return pb.Type(tensor=pb.TensorType(dtype=dtype.as_datatype_enum))
def test_serialize_type_with_placement(self):
  actual_proto = type_serialization.serialize_type(
      computation_types.PlacementType())
  expected_proto = pb.Type(placement=pb.PlacementType())
  self.assertEqual(actual_proto, expected_proto)
def _tuple_type_proto(elements):
  return pb.Type(struct=pb.StructType(element=elements))