def assert_serializes(self, fn, parameter_type, expected_fn_type_str):
    """Serialize `fn` via the TF computation serializer and verify the result.

    The serializer is a generator that is driven by hand here: the first value
    it yields is passed to `fn`, and the value `fn` returns is sent back in to
    obtain the serialized computation and its accompanying type spec.

    Args:
        fn: Callable invoked with the argument yielded by the serializer.
        parameter_type: Parameter type handed to the serializer.
        expected_fn_type_str: Expected compact representation of the type
            deserialized from the serialized computation.

    Returns:
        A pair `(tensorflow_field, type_spec)`: the `tensorflow` field of the
        serialized computation and the type spec returned by the serializer.
    """
    tf_serializer = tensorflow_serialization.tf_computation_serializer(
        parameter_type, context_stack_impl.context_stack
    )
    fn_arg = next(tf_serializer)
    fn_result = fn(fn_arg)
    serialized, type_spec = tf_serializer.send(fn_result)

    # The type stored in the proto must agree with the type spec returned
    # alongside it, and must print as the caller expects.
    actual_type = type_serialization.deserialize_type(serialized.type)
    type_test_utils.assert_types_equivalent(actual_type, type_spec)
    self.assertEqual(
        actual_type.compact_representation(), expected_fn_type_str
    )

    # Only the `tensorflow` variant of the `computation` oneof is acceptable.
    self.assertEqual(serialized.WhichOneof('computation'), 'tensorflow')
    return serialized.tensorflow, type_spec
def _tf_wrapper_fn(parameter_type, name):
    """Generator that plugs TensorFlow logic into the TFF framework.

    This function is passed through `computation_wrapper.ComputationWrapper`;
    documentation of its arguments can be found in the definition of that
    class.

    Because this is a generator, nothing below runs until it is first
    advanced. It then proceeds in two steps:

    1. Yields the argument produced by the TensorFlow serializer and waits to
       be sent the result computed from that argument.
    2. Yields a `computation_impl.ComputationImpl` built from the serialized
       computation.

    Raises:
        TypeError: On first advance, if `parameter_type` is not a
            TensorFlow-compatible type.
    """
    del name  # Unused.

    if not type_analysis.is_tensorflow_compatible_type(parameter_type):
        message = (
            '`tf_computation`s can accept only parameter types with '
            'constituents `SequenceType`, `StructType` '
            'and `TensorType`; you have attempted to create one '
            'with the type {}.'.format(parameter_type)
        )
        raise TypeError(message)

    stack = context_stack_impl.context_stack
    serializer = tensorflow_serialization.tf_computation_serializer(
        parameter_type, stack
    )

    # Hand the serializer's argument to whoever drives this generator and
    # receive the traced result back through `send`.
    serializer_arg = next(serializer)
    traced_result = yield serializer_arg

    proto, type_spec = serializer.send(traced_result)
    yield computation_impl.ComputationImpl(proto, stack, type_spec)
def _tf_wrapper_fn(parameter_type, name):
    """Generator that plugs TensorFlow logic into the TFF framework.

    Because this is a generator, nothing below runs until it is first
    advanced. It then proceeds as follows:

    1. Yields the argument produced by the TensorFlow serializer and waits to
       be sent the result computed from that argument. An `Exception` thrown
       into this generator at that point is forwarded to the serializer.
    2. Sends the result to the serializer, closes the serializer, and yields
       a `computation_impl.ConcreteComputation` built from the serialized
       computation.

    Raises:
        TypeError: On first advance, if `parameter_type` is not a
            TensorFlow-compatible type.
    """
    del name  # Unused.

    if not type_analysis.is_tensorflow_compatible_type(parameter_type):
        message = (
            '`tf_computation`s can accept only parameter types with '
            'constituents `SequenceType`, `StructType` '
            'and `TensorType`; you have attempted to create one '
            'with the type {}.'.format(parameter_type)
        )
        raise TypeError(message)

    stack = context_stack_impl.context_stack
    serializer = tensorflow_serialization.tf_computation_serializer(
        parameter_type, stack
    )
    serializer_arg = next(serializer)

    try:
        traced_result = yield serializer_arg
    except Exception as err:  # pylint: disable=broad-except
        # Forward failures raised by the driver into the serializer so that
        # it observes the same exception.
        serializer.throw(err)

    proto, type_spec = serializer.send(traced_result)
    serializer.close()
    yield computation_impl.ConcreteComputation(proto, stack, type_spec)
def _tf_computation_serializer(fn, parameter_type, context):
    """Run the generator-based TF computation serializer around `fn`.

    Args:
        fn: Callable invoked with the argument yielded by the serializer.
        parameter_type: Parameter type handed to the serializer.
        context: Context stack handed to the serializer.

    Returns:
        Whatever the serializer produces once the result of `fn` has been
        sent into it.
    """
    generator = tensorflow_serialization.tf_computation_serializer(
        parameter_type, context
    )
    # First step: obtain the argument the serializer prepared for `fn`.
    fn_input = next(generator)
    fn_output = fn(fn_input)
    # Second step: feed the result back to get the serialized output.
    return generator.send(fn_output)