def should_transform(self, comp):
  """Returns whether `comp` is a candidate for TensorFlow generation.

  A computation qualifies only when all of the following hold:

  * its type signature is TensorFlow-compatible, or it is a function type
    whose parameter and result are both TensorFlow-compatible;
  * it is neither a compiled computation nor a call to one;
  * it has no unbound references; and
  * it contains no `building_blocks.Intrinsic` nodes.

  Args:
    comp: The building block under consideration.

  Returns:
    `True` if `comp` should be transformed, `False` otherwise.
  """
  signature = comp.type_signature
  is_tf_compatible = type_analysis.is_tensorflow_compatible_type
  has_tf_type = is_tf_compatible(signature) or (
      signature.is_function() and
      is_tf_compatible(signature.parameter) and
      is_tf_compatible(signature.result))
  if not has_tf_type:
    return False

  # Compiled computations (or direct calls to them) are already the final
  # product of TF generation, so there is nothing left to transform.
  already_compiled = comp.is_compiled_computation() or (
      comp.is_call() and comp.function.is_compiled_computation())
  if already_compiled:
    return False

  # Captured (unbound) references cannot be represented without further
  # information about where they come from.
  unbound_by_comp = transformation_utils.get_map_of_unbound_references(comp)
  if unbound_by_comp[comp]:
    return False

  return not tree_analysis.contains_types(comp, building_blocks.Intrinsic)
def is_allowed_client_data_type(
    type_spec: computation_types.Type) -> bool:
  """Returns whether `type_spec` is acceptable as client data.

  Accepted types are:

  * a sequence type whose element type is TensorFlow-compatible, or
  * a struct type (possibly nested) all of whose children are themselves
    accepted types. A struct with no children is therefore accepted.

  Every other type, including a bare tensor type, is rejected.

  Args:
    type_spec: The type to check.

  Returns:
    `True` if `type_spec` is an allowed client data type, else `False`.
  """
  if type_spec.is_sequence():
    return type_analysis.is_tensorflow_compatible_type(type_spec.element)
  if type_spec.is_struct():
    # `map` is lazy, so `all` still stops at the first rejected child.
    return all(map(is_allowed_client_data_type, type_spec.children()))
  return False
def _tf_wrapper_fn(parameter_type, name):
  """Wrapper function to plug Tensorflow logic into the TFF framework.

  This function is passed through `computation_wrapper.ComputationWrapper`.
  Documentation of its arguments can be found inside the definition of that
  class.

  This is a generator with a two-step protocol:

  1. It first yields the value produced by the TF serializer's first step.
     Presumably this is the argument with which the caller should invoke the
     wrapped Python function (TODO confirm).
  2. It expects the caller to `send()` back that function's result. It then
     yields a `computation_impl.ComputationImpl` built from the serialized
     output.

  If the caller `throw()`s an exception at the first yield, the exception is
  forwarded into the serializer before propagating. The serializer is closed
  on every exit path rather than being left suspended.

  Args:
    parameter_type: The parameter type of the computation being built.
    name: Unused.

  Yields:
    First, the serializer's initial value; then the constructed
    `computation_impl.ComputationImpl`.

  Raises:
    TypeError: If `parameter_type` is not TensorFlow-compatible. Because this
      function is a generator, the error surfaces on the first `next()`
      rather than at call time.
  """
  del name  # Unused.
  if not type_analysis.is_tensorflow_compatible_type(parameter_type):
    raise TypeError('`tf_computation`s can accept only parameter types with '
                    'constituents `SequenceType`, `StructType` '
                    'and `TensorType`; you have attempted to create one '
                    'with the type {}.'.format(parameter_type))
  ctx_stack = context_stack_impl.context_stack
  tf_serializer = tensorflow_serialization.tf_computation_serializer(
      parameter_type, ctx_stack)
  try:
    arg = next(tf_serializer)
    try:
      result = yield arg
    except Exception as e:  # pylint: disable=broad-except
      # Let the serializer observe the failure so it can unwind whatever
      # state it holds while suspended. Re-raise in case it swallows the
      # error; otherwise we would fall through with `result` unbound.
      tf_serializer.throw(e)
      raise
    comp_pb, extra_type_spec = tf_serializer.send(result)
  finally:
    # Close the serializer on every path: success, error, or this generator
    # itself being closed while suspended. Closing a finished generator is
    # a no-op.
    tf_serializer.close()
  yield computation_impl.ComputationImpl(comp_pb, ctx_stack, extra_type_spec)
def _tf_wrapper_fn(target_fn, parameter_type, unpack, name=None):
  """Serializes a Python function into a TFF TensorFlow computation.

  This function is passed through `computation_wrapper.ComputationWrapper`;
  the full description of its arguments lives in that class.

  Args:
    target_fn: The Python function to serialize. It is first adapted by
      `function_utils.wrap_as_zero_or_one_arg_callable`.
    parameter_type: The computation's parameter type. It must be
      TensorFlow-compatible.
    unpack: Forwarded to `function_utils.wrap_as_zero_or_one_arg_callable`.
    name: Unused.

  Returns:
    A `computation_impl.ComputationImpl` wrapping the serialized proto.

  Raises:
    TypeError: If `parameter_type` is not TensorFlow-compatible.
  """
  del name  # Unused.
  # The adapter is built before the type check, exactly as in the original
  # ordering, so any error it raises takes precedence.
  wrapped_fn = function_utils.wrap_as_zero_or_one_arg_callable(
      target_fn, parameter_type, unpack)
  is_compatible = type_analysis.is_tensorflow_compatible_type(parameter_type)
  if not is_compatible:
    message = ('`tf_computation`s can accept only parameter types with '
               'constituents `SequenceType`, `StructType` '
               'and `TensorType`; you have attempted to create one '
               'with the type {}.').format(parameter_type)
    raise TypeError(message)
  stack = context_stack_impl.context_stack
  proto, type_spec = (
      tensorflow_serialization.serialize_py_fn_as_tf_computation(
          wrapped_fn, parameter_type, stack))
  return computation_impl.ComputationImpl(proto, stack, type_spec)
def _tf_wrapper_fn(parameter_type, name):
  """Wrapper function to plug Tensorflow logic into the TFF framework.

  This is a generator with a two-step protocol:

  1. It first yields the value produced by the TF serializer's first step.
     Presumably this is the argument with which the caller should invoke the
     wrapped Python function (TODO confirm).
  2. It expects the caller to `send()` back that function's result. It then
     yields a `computation_impl.ConcreteComputation` built from the
     serialized output.

  If the caller `throw()`s an exception at the first yield, the exception is
  forwarded into the serializer and then re-raised. The serializer is closed
  on every exit path rather than being left suspended.

  Args:
    parameter_type: The parameter type of the computation being built.
    name: Unused.

  Yields:
    First, the serializer's initial value; then the constructed
    `computation_impl.ConcreteComputation`.

  Raises:
    TypeError: If `parameter_type` is not TensorFlow-compatible. Because this
      function is a generator, the error surfaces on the first `next()`
      rather than at call time.
  """
  del name  # Unused.
  if not type_analysis.is_tensorflow_compatible_type(parameter_type):
    raise TypeError(
        '`tf_computation`s can accept only parameter types with '
        'constituents `SequenceType`, `StructType` '
        'and `TensorType`; you have attempted to create one '
        'with the type {}.'.format(parameter_type))
  ctx_stack = context_stack_impl.context_stack
  tf_serializer = tensorflow_serialization.tf_computation_serializer(
      parameter_type, ctx_stack)
  try:
    arg = next(tf_serializer)
    try:
      result = yield arg
    except Exception as e:  # pylint: disable=broad-except
      # Let the serializer observe the failure so it can unwind whatever
      # state it holds while suspended. If it swallows the error instead of
      # raising, re-raise here; otherwise we would reach `send(result)` with
      # `result` unbound and mask `e` behind an UnboundLocalError.
      tf_serializer.throw(e)
      raise
    comp_pb, extra_type_spec = tf_serializer.send(result)
  finally:
    # Close the serializer on every path, including this generator being
    # closed while suspended (GeneratorExit is not caught above). Closing a
    # finished generator is a no-op.
    tf_serializer.close()
  yield computation_impl.ConcreteComputation(comp_pb, ctx_stack,
                                             extra_type_spec)