def test_fetch_value_with_dataset_and_tensor(self):
  """Fetches a structure that mixes eager tensors with a `tf.data.Dataset`."""

  def return_dataset_and_tensor():
    return [tf.constant(0), tf.data.Dataset.range(5), tf.constant(5)]

  stack = context_stack_impl.context_stack
  comp_pb = tensorflow_serialization.serialize_py_fn_as_tf_computation(
      return_dataset_and_tensor, None, stack)[0]
  executable = computation_impl.ComputationImpl(comp_pb, stack)
  result = executable()
  self.assertEqual(result[0], 0)
  self.assertEqual(result[1], list(range(5)))
  self.assertEqual(result[2], 5)
def test_fetch_value_with_empty_dataset_and_tensors(self):
  """Fetches a tensor alongside a dataset whose `take(0)` yields no batches."""

  def return_dataset():
    ds1 = tf.data.Dataset.from_tensor_slices([[1, 1], [1, 1]])
    return [tf.constant([0., 0.]), ds1.batch(5).take(0)]

  stack = context_stack_impl.context_stack
  serialized = tensorflow_serialization.serialize_py_fn_as_tf_computation(
      return_dataset, None, stack)
  executable = computation_impl.ComputationImpl(serialized[0], stack)
  result = executable()
  self.assertEqual(result[0][0], 0.)
  self.assertEqual(result[0][1], 0.)
  # The empty dataset materializes as a single empty [0, 2] int32 batch.
  self.assertEqual(str(result[1][0]), str(np.zeros([0, 2], dtype=np.int32)))
def _federated_computation_wrapper_fn(parameter_type, name):
  """Wrapper function to plug orchestration logic into the TFF framework.

  This function is passed through `computation_wrapper.ComputationWrapper`.
  Documentation for its arguments can be found inside the definition of that
  class.
  """
  ctx_stack = context_stack_impl.context_stack
  # The serializer is itself a generator: its first yielded value is forwarded
  # to the wrapper, and the value the wrapper sends back (`result`) is then fed
  # into the serializer to finish serialization.
  fn_generator = federated_computation_utils.federated_computation_serializer(
      'arg' if parameter_type else None,  # Parameter name only if one exists.
      parameter_type,
      ctx_stack,
      suggested_name=name)
  result = yield next(fn_generator)
  target_lambda, extra_type_spec = fn_generator.send(result)
  # Second (final) yield: the fully-constructed, invocable computation.
  yield computation_impl.ComputationImpl(target_lambda.proto, ctx_stack,
                                         extra_type_spec)
def test_fetch_value_with_empty_dataset_and_tensors(self):
  """Checks structure and emptiness of a fetched zero-element dataset."""

  def return_dataset():
    ds1 = tf.data.Dataset.from_tensor_slices([[1, 1], [1, 1]])
    return [tf.constant([0., 0.]), ds1.batch(5).take(0)]

  stack = context_stack_impl.context_stack
  comp_pb = tensorflow_serialization.serialize_py_fn_as_tf_computation(
      return_dataset, None, stack)[0]
  executable = computation_impl.ComputationImpl(comp_pb, stack)
  result = executable()
  self.assertAllEqual(result[0], [0., 0.])
  expected_spec = tf.TensorSpec(shape=(None, 2), dtype=tf.int32)
  self.assertEqual(result[1].element_spec, expected_spec)
  # The dataset carries the right spec but must produce no elements at all.
  with self.assertRaises(StopIteration):
    _ = next(iter(result[1]))
def _federated_computation_wrapper_fn(target_fn, parameter_type, unpack, name=None):
  """Wrapper function to plug orchestration logic in to TFF framework."""
  wrapped_fn = function_utils.wrap_as_zero_or_one_arg_callable(
      target_fn, parameter_type, unpack)
  stack = context_stack_impl.context_stack
  parameter_name = 'arg' if parameter_type else None
  # Trace the wrapped function into a serialized lambda building block.
  building_block = federated_computation_utils.zero_or_one_arg_fn_to_building_block(
      wrapped_fn, parameter_name, parameter_type, stack, suggested_name=name)
  return computation_impl.ComputationImpl(building_block.proto, stack)
def _tf_wrapper_fn(target_fn, parameter_type, unpack, name=None):
  """Wrapper function to plug Tensorflow logic in to TFF framework."""
  del name  # Unused.
  wrapped_fn = function_utils.wrap_as_zero_or_one_arg_callable(
      target_fn, parameter_type, unpack)
  # Reject parameter types that TensorFlow serialization cannot represent.
  if not type_utils.is_tensorflow_compatible_type(parameter_type):
    raise TypeError(
        '`tf_computation`s can accept only parameter types with '
        'constituents `SequenceType`, `NamedTupleType` '
        'and `TensorType`; you have attempted to create one '
        'with the type {}.'.format(parameter_type))
  stack = context_stack_impl.context_stack
  comp_pb, extra_type_spec = (
      tensorflow_serialization.serialize_py_fn_as_tf_computation(
          wrapped_fn, parameter_type, stack))
  return computation_impl.ComputationImpl(comp_pb, stack, extra_type_spec)
def test_fetch_value_with_datasets_nested_at_second_level(self):
  """Fetches a tensor plus a nested list holding two datasets."""

  def return_two_datasets():
    return [
        tf.constant(0),
        [tf.data.Dataset.range(5), tf.data.Dataset.range(5)]
    ]

  stack = context_stack_impl.context_stack
  comp_pb = tensorflow_serialization.serialize_py_fn_as_tf_computation(
      return_two_datasets, None, stack)[0]
  executable = computation_impl.ComputationImpl(comp_pb, stack)
  result = executable()
  self.assertEqual(result[0], 0)
  # Both nested datasets materialize to the same five-element range.
  expected = list(range(5))
  self.assertEqual(result[1][0], expected)
  self.assertEqual(result[1][1], expected)
def compile(self, computation_to_compile):
  """Compiles `computation_to_compile`.

  Args:
    computation_to_compile: An instance of `computation_base.Computation` to
      compile.

  Returns:
    An instance of `computation_base.Computation` that represents the result.
  """
  py_typecheck.check_type(computation_to_compile,
                          computation_base.Computation)
  computation_proto = computation_impl.ComputationImpl.get_proto(
      computation_to_compile)
  py_typecheck.check_type(computation_proto, pb.Computation)
  comp = building_blocks.ComputationBuildingBlock.from_proto(
      computation_proto)
  # TODO(b/113123410): Add a compiler options argument that characterizes the
  # desired form of the output. To be driven by what the specific backend the
  # pipeline is targeting is able to understand. Pending a more fleshed out
  # design of the backend API.

  # Replace intrinsics with their bodies, for now manually in a fixed order.
  # TODO(b/113123410): Replace this with a more automated implementation that
  # does not rely on manual maintenance.
  comp, _ = value_transformations.replace_all_intrinsics_with_bodies(
      comp, self._context_stack)
  # Replaces called lambdas with LET constructs with a single local symbol.
  comp, _ = transformations.replace_called_lambda_with_block(comp)
  # Removes mapped or applied identities.
  comp, _ = transformations.remove_mapped_or_applied_identity(comp)
  # Remove duplicate computations. This is important! Otherwise the semantics
  # of non-deterministic computations (e.g. a `tff.tf_computation` depending
  # on `tf.random`) will give unexpected behavior. Additionally, this may
  # reduce the amount of calls into TF for some ASTs.
  comp, _ = transformations.uniquify_reference_names(comp)
  comp, _ = transformations.extract_computations(comp)
  comp, _ = transformations.remove_duplicate_computations(comp)
  return computation_impl.ComputationImpl(comp.proto, self._context_stack)
def _tf_wrapper_fn(target_fn, parameter_type, unpack, name=None):
  """Wrapper function to plug Tensorflow logic into the TFF framework.

  This function is passed through `computation_wrapper.ComputationWrapper`.
  Documentation for its arguments can be found inside the definition of that
  class.
  """
  del name  # Unused.
  wrapped_fn = function_utils.wrap_as_zero_or_one_arg_callable(
      target_fn, parameter_type, unpack)
  # Only TF-serializable parameter types are admissible here.
  if not type_analysis.is_tensorflow_compatible_type(parameter_type):
    raise TypeError('`tf_computation`s can accept only parameter types with '
                    'constituents `SequenceType`, `StructType` '
                    'and `TensorType`; you have attempted to create one '
                    'with the type {}.'.format(parameter_type))
  stack = context_stack_impl.context_stack
  comp_pb, extra_type_spec = (
      tensorflow_serialization.serialize_py_fn_as_tf_computation(
          wrapped_fn, parameter_type, stack))
  return computation_impl.ComputationImpl(comp_pb, stack, extra_type_spec)
def _tf_wrapper_fn(parameter_type, name):
  """Wrapper function to plug Tensorflow logic into the TFF framework.

  This function is passed through `computation_wrapper.ComputationWrapper`.
  Documentation for its arguments can be found inside the definition of that
  class.
  """
  del name  # Unused.
  # Fail fast on parameter types TensorFlow serialization cannot represent.
  if not type_analysis.is_tensorflow_compatible_type(parameter_type):
    raise TypeError('`tf_computation`s can accept only parameter types with '
                    'constituents `SequenceType`, `StructType` '
                    'and `TensorType`; you have attempted to create one '
                    'with the type {}.'.format(parameter_type))
  ctx_stack = context_stack_impl.context_stack
  # The serializer is a generator: its first yielded value is handed to the
  # wrapper, and the value the wrapper sends back (`result`) is fed into the
  # serializer to complete serialization.
  tf_serializer = tensorflow_serialization.tf_computation_serializer(
      parameter_type, ctx_stack)
  result = yield next(tf_serializer)
  comp_pb, extra_type_spec = tf_serializer.send(result)
  # Final yield: the invocable computation built from the serialized proto.
  yield computation_impl.ComputationImpl(comp_pb, ctx_stack, extra_type_spec)
def test_fetch_value_with_empty_structured_dataset_and_tensors(self):
  """Fetches a tensor and an empty dataset of OrderedDict elements."""

  def return_dataset():
    ds1 = tf.data.Dataset.from_tensor_slices(
        collections.OrderedDict([('a', [1, 1]), ('b', [1, 1])]))
    return [tf.constant([0., 0.]), ds1.batch(5).take(0)]

  stack = context_stack_impl.context_stack
  comp_pb = tensorflow_serialization.serialize_py_fn_as_tf_computation(
      return_dataset, None, stack)[0]
  executable = computation_impl.ComputationImpl(comp_pb, stack)
  result = executable()
  self.assertEqual(result[0][0], 0.)
  self.assertEqual(result[0][1], 0.)
  # Each field of the single fetched batch is an empty int32 vector.
  empty = np.zeros([0], dtype=np.int32)
  self.assertTrue(np.array_equal(result[1][0].a, empty))
  self.assertTrue(np.array_equal(result[1][0].b, empty))
def compile(self, computation_to_compile):
  """Compiles `computation_to_compile`.

  Args:
    computation_to_compile: An instance of `computation_base.Computation` to
      compile.

  Returns:
    An instance of `computation_base.Computation` that represents the result.
  """
  py_typecheck.check_type(computation_to_compile,
                          computation_base.Computation)
  computation_proto = computation_impl.ComputationImpl.get_proto(
      computation_to_compile)
  # TODO(b/113123410): Add a compiler options argument that characterizes the
  # desired form of the output. To be driven by what the specific backend the
  # pipeline is targeting is able to understand. Pending a more fleshed out
  # design of the backend API.
  py_typecheck.check_type(computation_proto, pb.Computation)
  comp = computation_building_blocks.ComputationBuildingBlock.from_proto(
      computation_proto)
  # Replace intrinsics with their bodies, for now manually in a fixed order.
  # TODO(b/113123410): Replace this with a more automated implementation that
  # does not rely on manual maintenance.
  for uri, body in six.iteritems(self._intrinsic_bodies):
    comp, _ = transformations.replace_intrinsic_with_callable(
        comp, uri, body, self._context_stack)
  # Replaces called lambdas with LET constructs with a single local symbol.
  comp, _ = transformations.replace_called_lambda_with_block(comp)
  # TODO(b/113123410): Add more transformations to simplify and optimize the
  # structure, e.g., such as:
  # * removing unnecessary lambdas,
  # * flattening the structure,
  # * merging TensorFlow blocks where appropriate,
  # * ...and so on.
  return computation_impl.ComputationImpl(comp.proto, self._context_stack)
def _federated_computation_wrapper_fn(target_fn, parameter_type, unpack, name=None):
  """Wrapper function to plug orchestration logic into the TFF framework.

  This function is passed through `computation_wrapper.ComputationWrapper`.
  Documentation for its arguments can be found inside the definition of that
  class.
  """
  wrapped_fn = function_utils.wrap_as_zero_or_one_arg_callable(
      target_fn, parameter_type, unpack)
  stack = context_stack_impl.context_stack
  parameter_name = 'arg' if parameter_type else None
  # Trace the wrapped function into a serialized lambda building block.
  building_block = federated_computation_utils.zero_or_one_arg_fn_to_building_block(
      wrapped_fn, parameter_name, parameter_type, stack, suggested_name=name)
  return computation_impl.ComputationImpl(building_block.proto, stack)
def test_fetch_value_with_empty_structured_dataset_and_tensors(self):
  """Checks the element structure and emptiness of a structured dataset."""

  def return_dataset():
    ds1 = tf.data.Dataset.from_tensor_slices(
        collections.OrderedDict([('a', [1, 1]), ('b', [1, 1])]))
    return [tf.constant([0., 0.]), ds1.batch(5).take(0)]

  stack = context_stack_impl.context_stack
  comp_pb = tensorflow_serialization.serialize_py_fn_as_tf_computation(
      return_dataset, None, stack)[0]
  executable = computation_impl.ComputationImpl(comp_pb, stack)
  result = executable()
  self.assertAllEqual(result[0], [0., 0.])
  expected_structure = collections.OrderedDict([
      ('a', tf.TensorSpec(shape=(None, ), dtype=tf.int32)),
      ('b', tf.TensorSpec(shape=(None, ), dtype=tf.int32)),
  ])
  self.assertEqual(
      tf.data.experimental.get_structure(result[1]), expected_structure)
  # The structure is preserved, but iterating must yield nothing.
  with self.assertRaises(StopIteration):
    _ = next(iter(result[1]))
def _jax_strategy_fn(fn_to_wrap, fn_name, parameter_type, unpack):
  """Serializes a Python function containing JAX code as a TFF computation.

  Args:
    fn_to_wrap: The Python function containing JAX code to be serialized as a
      computation containing XLA.
    fn_name: The name for the constructed computation (currently ignored).
    parameter_type: An instance of `computation_types.Type` that represents
      the TFF type of the computation parameter, or `None` if there's none.
    unpack: See `unpack` in `function_utils.create_argument_unpacking_fn`.

  Returns:
    An instance of `computation_impl.ComputationImpl` with the constructed
    computation.
  """
  del fn_name  # Unused.
  unpack_arguments = function_utils.create_argument_unpacking_fn(
      fn_to_wrap, parameter_type, unpack=unpack)
  stack = context_stack_impl.context_stack
  computation_pb = jax_serialization.serialize_jax_computation(
      fn_to_wrap, unpack_arguments, parameter_type, stack)
  return computation_impl.ComputationImpl(computation_pb, stack)
def test_something(self):
  # TODO(b/113112108): Revise these tests after a more complete implementation
  # is in place.

  # At the moment, this should succeed, as both the computation body and the
  # type are well-formed.
  fn_type = computation_types.FunctionType(tf.int32, tf.int32)
  well_formed = pb.Computation(
      type=type_serialization.serialize_type(fn_type),
      intrinsic=pb.Intrinsic(uri='whatever'))
  computation_impl.ComputationImpl(well_formed,
                                   context_stack_impl.context_stack)
  # This should fail, as the proto is not well-formed.
  with self.assertRaises(TypeError):
    computation_impl.ComputationImpl(pb.Computation(),
                                     context_stack_impl.context_stack)
  # This should fail, as "10" is not an instance of pb.Computation.
  with self.assertRaises(TypeError):
    computation_impl.ComputationImpl(10, context_stack_impl.context_stack)
def _federated_computation_wrapper_fn(parameter_type, name):
  """Wrapper function to plug orchestration logic into the TFF framework.

  This function is passed through `computation_wrapper.ComputationWrapper`.
  Documentation for its arguments can be found inside the definition of that
  class.
  """
  ctx_stack = context_stack_impl.context_stack
  # Only name the parameter when one actually exists.
  if parameter_type is None:
    parameter_name = None
  else:
    parameter_name = 'arg'
  # The serializer is a generator: its first yielded value is forwarded to the
  # wrapper, and the value sent back (`result`) completes serialization.
  fn_generator = federated_computation_utils.federated_computation_serializer(
      parameter_name=parameter_name,
      parameter_type=parameter_type,
      context_stack=ctx_stack,
      suggested_name=name)
  arg = next(fn_generator)
  try:
    result = yield arg
  except Exception as e:  # pylint: disable=broad-except
    # Forward exceptions raised at our yield point into the serializer so it
    # can unwind its own context.
    # NOTE(review): if the serializer were to swallow the thrown exception and
    # yield again, control would fall through to `send(result)` below with
    # `result` unbound — this presumes the serializer always re-raises;
    # confirm against `federated_computation_serializer`.
    fn_generator.throw(e)
  target_lambda, extra_type_spec = fn_generator.send(result)
  # Final yield: the fully-constructed, invocable computation.
  yield computation_impl.ComputationImpl(target_lambda.proto, ctx_stack,
                                         extra_type_spec)
def create_dummy_computation_impl():
  """Returns a `tff.ComputationImpl` and type."""
  comp, type_signature = create_dummy_computation_tensorflow_identity()
  impl = computation_impl.ComputationImpl(comp,
                                          context_stack_impl.context_stack)
  return impl, type_signature
def create_dummy_computation_impl():
  """Returns a dummy identity-lambda `ComputationImpl` and its type."""
  proto = executor_test_utils.create_dummy_identity_lambda_computation()
  impl = computation_impl.ComputationImpl(proto,
                                          context_stack_impl.context_stack)
  # The identity lambda is a unary int32 -> int32 operator.
  return impl, type_factory.unary_op(tf.int32)
def building_block_to_computation(building_block):
  """Converts a computation building block to a computation impl."""
  py_typecheck.check_type(building_block,
                          building_blocks.ComputationBuildingBlock)
  proto = building_block.proto
  return computation_impl.ComputationImpl(proto,
                                          context_stack_impl.context_stack)
def _to_computation_impl(building_block):
  """Wraps `building_block`'s proto as an invocable `ComputationImpl`."""
  stack = context_stack_impl.context_stack
  return computation_impl.ComputationImpl(building_block.proto, stack)