def test_returns_correct_structure_with_no_unbound_references(self):
  """Checks that a closed block becomes a no-argument call of compiled TF.

  The block chains two compiled identities over a TF constant. It returns the
  second result twice in a tuple. Nothing in the block refers to a name
  defined outside it, so the generated call is expected to have no argument.
  """
  concrete_int = building_block_factory.create_tensorflow_constant(
      tf.int32, 1)
  first_tf_id = building_block_factory.create_compiled_identity(tf.int32)
  called_tf_id = building_blocks.Call(first_tf_id, concrete_int)
  ref_to_call = building_blocks.Reference('call', called_tf_id.type_signature)
  second_tf_id = building_block_factory.create_compiled_identity(tf.int32)
  second_called = building_blocks.Call(second_tf_id, ref_to_call)
  # Type the reference from the call it actually names. Using the first
  # call's type only works because both identities share a result type.
  ref_to_second_call = building_blocks.Reference(
      'second_call', second_called.type_signature)
  block_locals = [('call', called_tf_id), ('second_call', second_called)]
  block = building_blocks.Block(
      block_locals,
      building_blocks.Tuple([ref_to_second_call, ref_to_second_call]))

  tf_representing_block, _ = (
      transformations.create_tensorflow_representing_block(block))

  self.assertEqual(tf_representing_block.type_signature, block.type_signature)
  self.assertIsInstance(tf_representing_block, building_blocks.Call)
  self.assertIsInstance(tf_representing_block.function,
                        building_blocks.CompiledComputation)
  self.assertIsNone(tf_representing_block.argument)
def test_returns_single_called_graph_after_resolving_multiple_variables(
    self):
  """Checks that chained identities on a free `var` collapse to one TF call.

  The block binds two compiled identities in sequence over the unbound
  reference `var`. The result is expected to be a single call to a compiled
  computation whose argument is that same reference.
  """
  ref_to_int = building_blocks.Reference('var', tf.int32)
  first_tf_id = building_block_factory.create_compiled_identity(tf.int32)
  called_tf_id = building_blocks.Call(first_tf_id, ref_to_int)
  ref_to_call = building_blocks.Reference('call', called_tf_id.type_signature)
  second_tf_id = building_block_factory.create_compiled_identity(tf.int32)
  second_called = building_blocks.Call(second_tf_id, ref_to_call)
  # Type the reference from the call it actually names, not from the first
  # call, which merely happens to have the same result type.
  ref_to_second_call = building_blocks.Reference(
      'second_call', second_called.type_signature)
  block_locals = [('call', called_tf_id), ('second_call', second_called)]
  block = building_blocks.Block(block_locals, ref_to_second_call)

  tf_representing_block, _ = (
      transformations.create_tensorflow_representing_block(block))

  self.assertEqual(tf_representing_block.type_signature, block.type_signature)
  self.assertIsInstance(tf_representing_block, building_blocks.Call)
  self.assertIsInstance(tf_representing_block.function,
                        building_blocks.CompiledComputation)
  self.assertIsInstance(tf_representing_block.argument,
                        building_blocks.Reference)
  self.assertEqual(tf_representing_block.argument.name, 'var')
def __setattr__(self, name, value):
  """Replaces element `name` of this (optionally federated) named tuple.

  Nothing is mutated in place. `self._comp` is rebound to a new computation
  that produces the tuple with element `name` set to `value`.

  Args:
    name: The element name. It is checked to be a `str` with
      `py_typecheck.check_type`.
    value: The new element value. It is converted with `to_value` before
      being embedded.
  """
  py_typecheck.check_type(name, str)
  _check_is_optionally_federated_named_tuple(
      self, "__setattr__('{}', {})".format(name, value))
  value_comp = ValueImpl.get_comp(to_value(value, None, self._context_stack))
  if _is_federated_named_tuple(self):
    new_comp = building_block_factory.create_federated_setattr_call(
        self._comp, name, value_comp)
  else:
    named_tuple_setattr_lambda = (
        building_block_factory.create_named_tuple_setattr_lambda(
            self._comp.type_signature, name, value_comp))
    new_comp = building_blocks.Call(named_tuple_setattr_lambda, self._comp)
  # Go through the base class so that assigning `_comp` does not re-enter
  # this override. Use the zero-argument `super()` form consistently.
  super().__setattr__('_comp', new_comp)
def test_returns_string_for_call_with_arg(self):
  """Checks every string representation of `a(data)`."""
  function_ref = building_blocks.Reference(
      'a', computation_types.FunctionType(tf.int32, tf.int32))
  data_arg = building_blocks.Data('data', tf.int32)
  call = building_blocks.Call(function_ref, data_arg)

  # pyformat: disable
  expected_structure = (
      ' Call\n'
      ' / \\\n'
      'Ref(a) data'
  )
  # pyformat: enable
  self.assertEqual(call.compact_representation(), 'a(data)')
  self.assertEqual(call.formatted_representation(), 'a(data)')
  self.assertEqual(call.structural_representation(), expected_structure)
def test_strip_placement_with_called_lambda(self):
  """Checks that placement is stripped from a lambda applied to a server value."""
  int_type = computation_types.TensorType(tf.int32)
  server_int_type = computation_types.at_server(int_type)
  outer_ref = building_blocks.Reference('outer', server_int_type)
  identity_fn = building_blocks.Lambda(
      'inner', server_int_type,
      building_blocks.Reference('inner', server_int_type))
  before = building_blocks.Call(identity_fn, outer_ref)

  after, modified = tree_transformations.strip_placement(before)

  self.assertTrue(modified)
  self.assert_has_no_intrinsics_nor_federated_types(after)
  type_test_utils.assert_types_identical(before.type_signature,
                                         server_int_type)
  type_test_utils.assert_types_identical(after.type_signature, int_type)
def test_basic_functionality_of_lambda_class(self):
  """Checks a `Lambda` for `arg -> arg.f(arg.f(arg.x))`.

  Covers its type, accessors, string forms and proto.
  """
  param_name = 'arg'
  param_type = [('f', computation_types.FunctionType(tf.int32, tf.int32)),
                ('x', tf.int32)]
  param_ref = building_blocks.Reference(param_name, param_type)
  select_f = building_blocks.Selection(param_ref, name='f')
  select_x = building_blocks.Selection(param_ref, name='x')
  inner_call = building_blocks.Call(select_f, select_x)
  lam = building_blocks.Lambda(param_name, param_type,
                               building_blocks.Call(select_f, inner_call))

  # Python-level accessors and string forms.
  self.assertEqual(str(lam.type_signature),
                   '(<f=(int32 -> int32),x=int32> -> int32)')
  self.assertEqual(lam.parameter_name, param_name)
  self.assertEqual(str(lam.parameter_type), '<f=(int32 -> int32),x=int32>')
  self.assertEqual(lam.result.compact_representation(), 'arg.f(arg.f(arg.x))')
  param_type_repr = (
      'NamedTupleType(['
      '(\'f\', FunctionType(TensorType(tf.int32), TensorType(tf.int32))), '
      '(\'x\', TensorType(tf.int32))])')
  expected_repr = (
      'Lambda(\'arg\', {0}, '
      'Call(Selection(Reference(\'arg\', {0}), name=\'f\'), '
      'Call(Selection(Reference(\'arg\', {0}), name=\'f\'), '
      'Selection(Reference(\'arg\', {0}), name=\'x\'))))').format(
          param_type_repr)
  self.assertEqual(repr(lam), expected_repr)
  self.assertEqual(lam.compact_representation(),
                   '(arg -> arg.f(arg.f(arg.x)))')

  # Serialized form.
  lam_proto = lam.proto
  self.assertEqual(type_serialization.deserialize_type(lam_proto.type),
                   lam.type_signature)
  self.assertEqual(lam_proto.WhichOneof('computation'), 'lambda')
  # `lambda` is a Python keyword, so the proto field is read with getattr.
  lambda_proto = getattr(lam_proto, 'lambda')
  self.assertEqual(lambda_proto.parameter_name, param_name)
  self.assertEqual(str(lambda_proto.result), str(lam.result.proto))
  self._serialize_deserialize_roundtrip_test(lam)
def test_block_lambda_block_lambda(self):
  """Checks renaming across nested blocks and lambdas that all reuse `a`."""
  ref_a = building_blocks.Reference('a', tf.int32)
  inner_call = building_blocks.Call(
      building_blocks.Lambda('a', tf.int32, ref_a), ref_a)
  inner_block = building_blocks.Block([('a', ref_a), ('a', ref_a)], inner_call)
  outer_call = building_blocks.Call(
      building_blocks.Lambda('a', tf.int32, inner_block), ref_a)
  data = building_blocks.Data('data', tf.int32)
  outer_block = building_blocks.Block([('a', data), ('a', ref_a)], outer_call)

  renamed, modified = tree_transformations.uniquify_reference_names(
      outer_block)

  self.assertEqual(
      outer_block.compact_representation(),
      '(let a=data,a=a in (a -> (let a=a,a=a in (a -> a)(a)))(a))')
  self.assertEqual(
      renamed.compact_representation(),
      '(let a=data,_var1=a in (_var2 -> (let _var3=_var2,_var4=_var3 in (_var5 -> _var5)(_var4)))(_var1))'
  )
  tree_analysis.check_has_unique_names(renamed)
  self.assertTrue(modified)
def test_does_not_remove_called_lambda(self):
  """Checks that a directly applied identity lambda is left untouched."""
  identity_fn = building_block_test_utils.create_identity_function(
      'a', tf.int32)
  comp = building_blocks.Call(identity_fn,
                              building_blocks.Data('data', tf.int32))

  transformed_comp, modified = (
      tree_transformations.remove_mapped_or_applied_identity(comp))

  self.assertEqual(transformed_comp.compact_representation(),
                   comp.compact_representation())
  self.assertEqual(transformed_comp.compact_representation(),
                   '(a -> a)(data)')
  self.assertEqual(transformed_comp.type_signature, comp.type_signature)
  self.assertFalse(modified)
def test_executes_correctly_after_resolving_multiple_variables(self):
  """Checks that the collapsed TF for two chained identities echoes its input.

  The block applies two compiled int32 identities in sequence to the free
  reference `var`. The generated compiled function is run on 1 and on 0, and
  each input is expected to come back unchanged.
  """
  ref_to_int = building_blocks.Reference('var', tf.int32)
  first_tf_id_type = computation_types.TensorType(tf.int32)
  first_tf_id = building_block_factory.create_compiled_identity(
      first_tf_id_type)
  called_tf_id = building_blocks.Call(first_tf_id, ref_to_int)
  ref_to_call = building_blocks.Reference('call', called_tf_id.type_signature)
  second_tf_id_type = computation_types.TensorType(tf.int32)
  second_tf_id = building_block_factory.create_compiled_identity(
      second_tf_id_type)
  second_called = building_blocks.Call(second_tf_id, ref_to_call)
  # Type the reference from the call it actually names, rather than relying
  # on both identities sharing a result type.
  ref_to_second_call = building_blocks.Reference(
      'second_call', second_called.type_signature)
  block_locals = [('call', called_tf_id), ('second_call', second_called)]
  block = building_blocks.Block(block_locals, ref_to_second_call)

  tf_representing_block, _ = (
      transformations.create_tensorflow_representing_block(block))

  result_one = test_utils.run_tensorflow(
      tf_representing_block.function.proto, 1)
  self.assertEqual(result_one, 1)
  result_zero = test_utils.run_tensorflow(
      tf_representing_block.function.proto, 0)
  self.assertEqual(result_zero, 0)
def test_broadcast_dependent_on_aggregate_fails_well(self):
  """Checks that a broadcast depending on an aggregate is rejected clearly.

  Builds a `next` that calls the example process's `next` twice. Element 0 of
  the first call's result is fed into the second call. Converting this
  process to canonical form is expected to raise a `ValueError`.
  """
  cf = test_utils.get_temperature_sensor_example()
  it = canonical_form_utils.get_iterative_process_for_canonical_form(cf)
  next_comp = test_utils.computation_to_building_block(it.next)
  param_name = next_comp.parameter_name
  param_type = next_comp.parameter_type

  top_level_param = building_blocks.Reference(param_name, param_type)
  first_result = building_blocks.Call(next_comp, top_level_param)
  chained_arg = building_blocks.Struct([
      building_blocks.Selection(first_result, index=0),
      building_blocks.Selection(top_level_param, index=1),
  ])
  second_result = building_blocks.Call(next_comp, chained_arg)
  not_reducible = building_blocks.Lambda(param_name, param_type,
                                         second_result)
  not_reducible_it = iterative_process.IterativeProcess(
      it.initialize,
      computation_wrapper_instances.building_block_to_computation(
          not_reducible))

  with self.assertRaisesRegex(ValueError, 'broadcast dependent on aggregate'):
    canonical_form_utils.get_canonical_form_for_iterative_process(
        not_reducible_it)
def sequence_reduce(self, value, zero, op):
  """Implements `sequence_reduce` as defined in `api/intrinsics.py`.

  `value` may be a plain sequence or a federated sequence.

  - Plain sequence: the reduction is built directly with
    `create_sequence_reduce`.
  - Federated sequence at SERVER or CLIENTS: a lambda applying the
    `SEQUENCE_REDUCE` intrinsic to a single member sequence is passed to
    `self.federated_map` together with `value`.

  Args:
    value: The sequence, or federated sequence, to reduce.
    zero: The initial accumulator value.
    op: The reduction operator. It must be assignable to
      `type_factory.reduction_op(zero type, element type)`.

  Returns:
    A `value_impl.ValueImpl` holding the reduction.

  Raises:
    TypeError: If `op` has an incompatible type, or if a federated `value`
      has a placement other than SERVER or CLIENTS. A non-sequence,
      non-federated `value` is also rejected via `py_typecheck.check_type`.
  """
  value, zero, op = (
      value_impl.to_value(operand, None, self._context_stack)
      for operand in (value, zero, op))

  value_type = value.type_signature
  if isinstance(value_type, computation_types.SequenceType):
    element_type = value_type.element
  else:
    py_typecheck.check_type(value_type, computation_types.FederatedType)
    py_typecheck.check_type(value_type.member,
                            computation_types.SequenceType)
    element_type = value_type.member.element

  expected_op_type = type_factory.reduction_op(zero.type_signature,
                                               element_type)
  if not type_analysis.is_assignable_from(expected_op_type,
                                          op.type_signature):
    raise TypeError('Expected an operator of type {}, got {}.'.format(
        expected_op_type, op.type_signature))

  value, zero, op = (
      value_impl.ValueImpl.get_comp(operand) for operand in (value, zero, op))

  if isinstance(value.type_signature, computation_types.SequenceType):
    return value_impl.ValueImpl(
        building_block_factory.create_sequence_reduce(value, zero, op),
        self._context_stack)

  # Federated case: build `(arg -> SEQUENCE_REDUCE(<arg, zero, op>))` over a
  # single member sequence, then map it over the federated value.
  member_sequence_type = computation_types.SequenceType(element_type)
  reduce_type = computation_types.FunctionType(
      (member_sequence_type, zero.type_signature, op.type_signature),
      op.type_signature.result)
  reduce_intrinsic = building_blocks.Intrinsic(
      intrinsic_defs.SEQUENCE_REDUCE.uri, reduce_type)
  member_ref = building_blocks.Reference('arg', member_sequence_type)
  reduce_call = building_blocks.Call(
      reduce_intrinsic, building_blocks.Tuple((member_ref, zero, op)))
  reduce_fn = building_blocks.Lambda(member_ref.name,
                                     member_ref.type_signature, reduce_call)
  fn_impl = value_impl.ValueImpl(reduce_fn, self._context_stack)

  placement = value.type_signature.placement
  if placement not in (placement_literals.SERVER, placement_literals.CLIENTS):
    raise TypeError('Unsupported placement {}.'.format(placement))
  return self.federated_map(fn_impl, value)
def __add__(self, other):
  """Returns a value computing `self + other` with the generic plus intrinsic.

  Args:
    other: The right-hand operand. It is converted with `to_value` first.

  Returns:
    A new `ValueImpl` wrapping a call to `GENERIC_PLUS` on `[self, other]`.
    Its result type is this value's type.

  Raises:
    TypeError: If `other`'s type is not equivalent to this value's type.
  """
  other = to_value(other, None, self._context_stack)
  self_type = self.type_signature
  other_type = other.type_signature
  if not type_utils.are_equivalent_types(self_type, other_type):
    raise TypeError('Cannot add {} and {}.'.format(self_type, other_type))
  plus_intrinsic = building_blocks.Intrinsic(
      intrinsic_defs.GENERIC_PLUS.uri,
      computation_types.FunctionType([self_type, self_type], self_type))
  operands = ValueImpl.get_comp(
      to_value([self, other], None, self._context_stack))
  return ValueImpl(
      building_blocks.Call(plus_intrinsic, operands), self._context_stack)
def __add__(self, other):
  """Returns a value computing `self + other` with the generic plus intrinsic.

  Args:
    other: The right-hand operand. It is converted with `to_value` first.

  Returns:
    A new `ValueImpl` wrapping a call to `GENERIC_PLUS` on `[self, other]`.
    Its result type is this value's type.

  Raises:
    TypeError: If `other`'s type is not equivalent to this value's type.
  """
  other = to_value(other, None, self._context_stack)
  self_type = self.type_signature
  other_type = other.type_signature
  if not self_type.is_equivalent_to(other_type):
    raise TypeError('Cannot add {} and {}.'.format(self_type, other_type))
  plus_intrinsic = building_blocks.Intrinsic(
      intrinsic_defs.GENERIC_PLUS.uri,
      computation_types.FunctionType([self_type, self_type], self_type))
  operands = ValueImpl.get_comp(
      to_value([self, other], None, self._context_stack))
  # TODO(b/159281959): Follow up and bind a reference here.
  return ValueImpl(
      building_blocks.Call(plus_intrinsic, operands), self._context_stack)
def create_whimsy_called_intrinsic(parameter_name, parameter_type=tf.int32):
  r"""Returns a whimsy called intrinsic.

             Call
            /    \
  intrinsic      Ref(x)

  Args:
    parameter_name: The name of the parameter.
    parameter_type: The type of the parameter.

  Returns:
    A `building_blocks.Call` that applies an intrinsic with URI
    `'intrinsic'`, of type `(parameter_type -> parameter_type)`, to a
    reference named `parameter_name` of type `parameter_type`.
  """
  ref = building_blocks.Reference(parameter_name, parameter_type)
  intrinsic = building_blocks.Intrinsic(
      'intrinsic',
      computation_types.FunctionType(parameter_type, parameter_type))
  return building_blocks.Call(intrinsic, ref)
def test_returns_string_for_call_with_no_arg(self):
  """Checks every string representation of a no-argument call `a()`."""
  no_arg_fn = building_blocks.Reference(
      'a', computation_types.FunctionType(None, tf.int32))
  call = building_blocks.Call(no_arg_fn)

  self.assertEqual(call.compact_representation(), 'a()')
  self.assertEqual(call.formatted_representation(), 'a()')
  # pyformat: disable
  self.assertEqual(
      call.structural_representation(),
      ' Call\n'
      ' /\n'
      'Ref(a)'
  )
  # pyformat: enable
def __setattr__(self, name, value):
  """Replaces element `name` of this (optionally federated) struct.

  `self._comp` is rebound to a new computation with element `name` set to
  `value`. In the non-federated case the new call is first bound to a
  reference in the current context, via `bind_computation_to_reference`, and
  that reference is stored.

  Args:
    name: The element name. It is checked to be a `str` with
      `py_typecheck.check_type`.
    value: The new element value. It is converted with `to_value` before
      being embedded.
  """
  py_typecheck.check_type(name, str)
  _check_struct_or_federated_struct(self, name)
  value_comp = ValueImpl.get_comp(to_value(value, None, self._context_stack))
  if _is_federated_named_tuple(self):
    replacement = building_block_factory.create_federated_setattr_call(
        self._comp, name, value_comp)
  else:
    setattr_fn = building_block_factory.create_named_tuple_setattr_lambda(
        self.type_signature, name, value_comp)
    updated_call = building_blocks.Call(setattr_fn, self._comp)
    current_context = self._context_stack.current
    replacement = current_context.bind_computation_to_reference(updated_call)
  # Go through the base class so that assigning `_comp` does not re-enter
  # this override.
  super().__setattr__('_comp', replacement)
def __add__(self, other):
  """Returns a value computing `self + other` with the generic plus intrinsic.

  The call to `GENERIC_PLUS` is bound to a reference in the current context.
  The returned value wraps that reference.

  Args:
    other: The right-hand operand. It is converted with `to_value` first.

  Returns:
    A new `ValueImpl` wrapping the bound reference.

  Raises:
    TypeError: If `other`'s type is not equivalent to this value's type.
  """
  other = to_value(other, None, self._context_stack)
  self_type = self.type_signature
  other_type = other.type_signature
  if not self_type.is_equivalent_to(other_type):
    raise TypeError('Cannot add {} and {}.'.format(self_type, other_type))
  plus_intrinsic = building_blocks.Intrinsic(
      intrinsic_defs.GENERIC_PLUS.uri,
      computation_types.FunctionType([self_type, self_type], self_type))
  operands = ValueImpl.get_comp(
      to_value([self, other], None, self._context_stack))
  plus_call = building_blocks.Call(plus_intrinsic, operands)
  bound_ref = self._context_stack.current.bind_computation_to_reference(
      plus_call)
  return ValueImpl(bound_ref, self._context_stack)
def test_replaces_lambda_to_called_graph_with_tf_of_same_type(self):
  """Checks that `(x -> identity_tf(x))` collapses into one compiled computation."""
  identity_tf = building_block_factory.create_compiled_identity(
      computation_types.TensorType(tf.int32))
  x_ref = building_blocks.Reference('x', tf.int32)
  lambda_wrapper = building_blocks.Lambda(
      'x', tf.int32, building_blocks.Call(identity_tf, x_ref))

  parsed, modified = parse_tff_to_tf(lambda_wrapper)

  self.assertIsInstance(parsed, building_blocks.CompiledComputation)
  self.assertTrue(modified)
  # TODO(b/157172423): change to assertEqual when Py container is preserved.
  parsed.type_signature.check_equivalent_to(lambda_wrapper.type_signature)
  self.assertEqual(2, test_utils.run_tensorflow(parsed.proto, 2))
def __call__(self, *args, **kwargs):
  """Invokes this functional value and returns a reference to the call.

  Positional and keyword arguments are converted with `to_value` and packed
  against the function's parameter type with `function_utils.pack_args`. The
  resulting call is bound to a reference via
  `_bind_computation_to_reference`.

  Raises:
    SyntaxError: If this value's type is not a function type.
  """
  fn_type = self.type_signature
  if not fn_type.is_function():
    raise SyntaxError(
        'Function-like invocation is only supported for values of functional '
        'types, but the value being invoked is of type {} that does not '
        'support invocation.'.format(fn_type))

  call_arg = None
  if args or kwargs:
    positional = [to_value(a, None) for a in args]
    keyword = {key: to_value(val, None) for key, val in kwargs.items()}
    packed = function_utils.pack_args(fn_type.parameter, positional, keyword)
    call_arg = to_value(packed, None).comp

  call = building_blocks.Call(self._comp, call_arg)
  return Value(_bind_computation_to_reference(call, 'calling a `tff.Value`'))
def construct_tensorflow_calling_lambda_on_concrete_arg(
    parameter: building_blocks.Reference,
    body: building_blocks.ComputationBuildingBlock,
    concrete_arg: building_blocks.ComputationBuildingBlock):
  """Generates TensorFlow for invoking a lambda on a given argument.

  That is, generates a TensorFlow block that encapsulates invoking the
  function with parameter `parameter` and body `body` on the argument
  `concrete_arg`. Via the guarantee made in
  `compiled_computation_transforms.TupleCalledGraphs`, this function claims
  that the computations defining `concrete_arg` will be executed exactly once
  in the generated TensorFlow.

  Args:
    parameter: Instance of `building_blocks.Reference` defining the parameter
      of the function to be generated and invoked, as described above. After
      this transformation, every instance of `parameter` in `body` represents
      a reference to `concrete_arg`.
    body: `building_blocks.ComputationBuildingBlock` representing the body of
      the function for which TensorFlow is generated.
    concrete_arg: `building_blocks.ComputationBuildingBlock` representing the
      argument passed to the resulting function. Every occurrence of
      `parameter` in `body` refers to it, so `concrete_arg` must have a type
      signature equivalent to that of `parameter`.

  Returns:
    A called `building_blocks.CompiledComputation`, as specified above.

  Raises:
    TypeError: If the arguments are of the wrong types, or if the type
      signature of `concrete_arg` does not match that of `parameter`.
  """
  for candidate, expected_cls in (
      (parameter, building_blocks.Reference),
      (body, building_blocks.ComputationBuildingBlock),
      (concrete_arg, building_blocks.ComputationBuildingBlock),
  ):
    py_typecheck.check_type(candidate, expected_cls)
  type_analysis.check_equivalent_types(parameter.type_signature,
                                       concrete_arg.type_signature)
  lambda_as_tf = _generate_simple_tensorflow(
      building_blocks.Lambda(parameter.name, parameter.type_signature, body))
  return _generate_simple_tensorflow(
      building_blocks.Call(lambda_as_tf, concrete_arg))
def test_replaces_lambda_to_called_graph_with_tf_of_same_type(self):
  """Checks that the parsed TF matches the lambda in type and in execution."""
  identity_tf = building_block_factory.create_compiled_identity(tf.int32)
  x_ref = building_blocks.Reference('x', tf.int32)
  lambda_wrapper = building_blocks.Lambda(
      'x', tf.int32, building_blocks.Call(identity_tf, x_ref))

  parsed, modified = parse_tff_to_tf(lambda_wrapper)
  to_computation = computation_wrapper_instances.building_block_to_computation
  exec_lambda = to_computation(lambda_wrapper)
  exec_tf = to_computation(parsed)

  self.assertIsInstance(parsed, building_blocks.CompiledComputation)
  self.assertTrue(modified)
  self.assertEqual(parsed.type_signature, lambda_wrapper.type_signature)
  self.assertEqual(exec_lambda(2), exec_tf(2))
def _wrap_constant_as_value(const, context_stack):
  """Wraps a Python constant as a `tff.Value`.

  The constant is serialized as a TF computation that returns
  `tf.constant(const)`. The computation is then wrapped in a no-argument call.

  Args:
    const: Python constant to convert into a TFF value. Anything that
      `tf.constant` accepts can be passed in.
    context_stack: The context stack to use.

  Returns:
    An instance of `value_base.Value`.
  """
  py_typecheck.check_type(context_stack, context_stack_base.ContextStack)
  tf_comp, _ = tensorflow_serialization.serialize_py_fn_as_tf_computation(
      lambda: tf.constant(const), None, context_stack)
  no_arg_call = building_blocks.Call(
      building_blocks.CompiledComputation(tf_comp))
  return ValueImpl(no_arg_call, context_stack)
def __call__(self, *args, **kwargs):
  """Invokes this functional value and returns a `ValueImpl` of the call.

  Positional and keyword arguments are converted with `to_value` and packed
  against the function's parameter type with `function_utils.pack_args`.
  With no arguments, the call has no argument.

  Raises:
    SyntaxError: If the underlying computation is not of a `FunctionType`.
  """
  fn_type = self._comp.type_signature
  if not isinstance(fn_type, computation_types.FunctionType):
    raise SyntaxError(
        'Function-like invocation is only supported for values of functional '
        'types, but the value being invoked is of type {} that does not '
        'support invocation.'.format(fn_type))

  stack = self._context_stack
  if not (args or kwargs):
    return ValueImpl(building_blocks.Call(self._comp, None), stack)

  positional = [to_value(a, None, stack) for a in args]
  keyword = {key: to_value(val, None, stack) for key, val in kwargs.items()}
  packed = function_utils.pack_args(fn_type.parameter, positional, keyword,
                                    stack.current)
  call_arg = ValueImpl.get_comp(to_value(packed, None, stack))
  return ValueImpl(building_blocks.Call(self._comp, call_arg), stack)
def test_with_structure_replacing_federated_map(self):
  """Checks that a block binding a (fn, arg) tuple and then calling is reduced.

  After the transformation, no lambdas or blocks are expected to remain.
  """
  arg_ref = building_blocks.Reference(
      'arg', [computation_types.FunctionType(tf.int32, tf.int32), tf.int32])
  call_selected = building_blocks.Call(
      building_blocks.Selection(arg_ref, index=0),
      building_blocks.Selection(arg_ref, index=1))
  identity_fn = building_blocks.Lambda(
      'x', tf.int32, building_blocks.Reference('x', tf.int32))
  bound_tuple = building_blocks.Tuple(
      [identity_fn, building_blocks.Data('a', tf.int32)])
  generated_structure = building_blocks.Block([('arg', bound_tuple)],
                                              call_selected)

  lambdas_and_blocks_removed, modified = (
      compiler_transformations.remove_lambdas_and_blocks(generated_structure))

  self.assertTrue(modified)
  self.assertNoLambdasOrBlocks(lambdas_and_blocks_removed)
def test_returns_single_called_graph_with_selection_in_result(self):
  """Checks that selecting from a bound call still yields one call on `var`."""
  var_ref = building_blocks.Reference('var', [tf.int32, tf.int32])
  identity_tf = building_block_factory.create_compiled_identity(
      var_ref.type_signature)
  bound_call = building_blocks.Call(identity_tf, var_ref)
  call_ref = building_blocks.Reference('call', bound_call.type_signature)
  block = building_blocks.Block(
      [('call', bound_call)], building_blocks.Selection(call_ref, index=0))

  tf_representing_block, _ = (
      compiler_transformations.create_tensorflow_representing_block(block))

  self.assertEqual(tf_representing_block.type_signature, block.type_signature)
  self.assertIsInstance(tf_representing_block, building_blocks.Call)
  self.assertIsInstance(tf_representing_block.function,
                        building_blocks.CompiledComputation)
  self.assertIsInstance(tf_representing_block.argument,
                        building_blocks.Reference)
  self.assertEqual(tf_representing_block.argument.name, 'var')
def test_replaces_lambda_to_called_tf_block_with_replicated_lambda_arg_with_tf_block_of_same_type(
    self):
  """Checks a lambda that passes `x` twice into a TF sum.

  The lambda is expected to collapse into equivalent TF.
  """
  sum_and_add_one = _create_compiled_computation(
      lambda x: x[0] + x[1] + 1, [tf.int32, tf.int32])
  x_ref = building_blocks.Reference('x', tf.int32)
  summed = building_blocks.Call(sum_and_add_one,
                                building_blocks.Tuple((x_ref, x_ref)))
  lambda_wrapper = building_blocks.Lambda('x', tf.int32, summed)

  parsed, modified = parse_tff_to_tf(lambda_wrapper)
  to_computation = computation_wrapper_instances.building_block_to_computation
  exec_lambda = to_computation(lambda_wrapper)
  exec_tf = to_computation(parsed)

  self.assertIsInstance(parsed, building_blocks.CompiledComputation)
  self.assertTrue(modified)
  self.assertEqual(parsed.type_signature, lambda_wrapper.type_signature)
  self.assertEqual(exec_lambda(17), exec_tf(17))
def _create_before_and_after_broadcast_for_no_broadcast(tree):
  """Creates before- and after-broadcast computations for the given `tree`.

  This function is intended for `get_canonical_form_for_iterative_process`.
  It creates before- and after-broadcast computations for `tree` when `tree`
  contains no `intrinsic_defs.FEDERATED_BROADCAST`.

  NOTE: This function does not assert that `tree` contains no
  `intrinsic_defs.FEDERATED_BROADCAST`. The caller is expected to perform this
  check before calling it.

  Args:
    tree: An instance of `building_blocks.ComputationBuildingBlock`.

  Returns:
    A pair of the form `(before, after)`. Each of `before` and `after` is a
    `tff_framework.ComputationBuildingBlock` that represents a part of the
    result as specified by
    `transformations.force_align_and_split_by_intrinsics`.
  """
  names = building_block_factory.unique_name_generator(tree)

  # `before` does not use its parameter; it yields an empty tuple at SERVER.
  before_param_name = next(names)
  server_empty_tuple = building_block_factory.create_federated_value(
      building_blocks.Tuple([]), placements.SERVER)
  before_broadcast = building_blocks.Lambda(before_param_name,
                                            tree.type_signature.parameter,
                                            server_empty_tuple)

  # `after` takes <original parameter, clients-placed broadcast value> and
  # applies `tree` to element 0 only.
  after_param_name = next(names)
  clients_type = computation_types.FederatedType(
      before_broadcast.type_signature.result.member, placements.CLIENTS)
  after_ref = building_blocks.Reference(
      after_param_name,
      computation_types.NamedTupleType(
          [tree.type_signature.parameter, clients_type]))
  tree_call = building_blocks.Call(
      tree, building_blocks.Selection(after_ref, index=0))
  after_broadcast = building_blocks.Lambda(after_ref.name,
                                           after_ref.type_signature, tree_call)
  return before_broadcast, after_broadcast
def _wrap_computation_as_value(
    proto: pb.Computation,
    context_stack: context_stack_base.ContextStack) -> value_base.Value:
  """Wraps the given computation proto as a `tff.Value`.

  The proto is compiled into a no-argument call. That call is bound to a
  reference in the current context of `context_stack`.

  Args:
    proto: A pb.Computation.
    context_stack: The context stack to use.

  Returns:
    A `value_base.Value`.
  """
  py_typecheck.check_type(proto, pb.Computation)
  py_typecheck.check_type(context_stack, context_stack_base.ContextStack)
  no_arg_call = building_blocks.Call(building_blocks.CompiledComputation(proto))
  bound_ref = context_stack.current.bind_computation_to_reference(no_arg_call)
  return ValueImpl(bound_ref, context_stack)
def _create_dummy_before_and_after_broadcast(comp):
  """Creates before- and after-broadcast computations for the given `comp`.

  This function is intended to be used instead of
  `transformations.force_align_and_split_by_intrinsic`. It generates dummy
  before and after computations when `comp` contains no
  `intrinsic_defs.FEDERATED_BROADCAST`.

  Note: This function does not assert that `comp` contains no
  `intrinsic_defs.FEDERATED_BROADCAST`. The caller is expected to perform this
  check before calling it.

  Args:
    comp: An instance of `building_blocks.ComputationBuildingBlock`.

  Returns:
    A pair of the form `(before, after)`. Each of `before` and `after` is a
    `tff_framework.ComputationBuildingBlock` that represents a part of the
    result as specified by
    `transformations.force_align_and_split_by_intrinsic`.
  """
  names = building_block_factory.unique_name_generator(comp)

  # `before` does not use its parameter; it yields an empty tuple at SERVER.
  before_param_name = six.next(names)
  server_empty_tuple = building_block_factory.create_federated_value(
      building_blocks.Tuple([]), placements.SERVER)
  before_broadcast = building_blocks.Lambda(before_param_name,
                                            comp.type_signature.parameter,
                                            server_empty_tuple)

  # `after` takes <original parameter, clients-placed broadcast value> and
  # applies `comp` to element 0 only.
  after_param_name = six.next(names)
  clients_type = computation_types.FederatedType(
      before_broadcast.type_signature.result.member, placements.CLIENTS)
  after_ref = building_blocks.Reference(
      after_param_name,
      computation_types.NamedTupleType(
          [comp.type_signature.parameter, clients_type]))
  comp_call = building_blocks.Call(
      comp, building_blocks.Selection(after_ref, index=0))
  after_broadcast = building_blocks.Lambda(after_ref.name,
                                           after_ref.type_signature, comp_call)
  return before_broadcast, after_broadcast
def test_replaces_lambda_to_called_graph_with_tf_of_same_type(self):
  """Checks that the parsed TF matches the lambda in type and in execution."""
  identity_tf = building_block_factory.create_compiled_identity(
      computation_types.TensorType(tf.int32))
  x_ref = building_blocks.Reference('x', tf.int32)
  lambda_wrapper = building_blocks.Lambda(
      'x', tf.int32, building_blocks.Call(identity_tf, x_ref))

  parsed, modified = parse_tff_to_tf(lambda_wrapper)
  to_computation = computation_wrapper_instances.building_block_to_computation
  exec_lambda = to_computation(lambda_wrapper)
  exec_tf = to_computation(parsed)

  self.assertIsInstance(parsed, building_blocks.CompiledComputation)
  self.assertTrue(modified)
  # FIXME(b/157172423) change to assertEqual when Py container is preserved.
  parsed.type_signature.check_equivalent_to(lambda_wrapper.type_signature)
  self.assertEqual(exec_lambda(2), exec_tf(2))