def test_construct_setattr_named_tuple_type_replaces_single_element(self):
    """Setting `a` binds the new value and copies `b` from the argument."""
    tuple_type = computation_types.NamedTupleType([('a', tf.int32),
                                                   ('b', tf.bool)])
    new_value = computation_building_blocks.Data('x', tf.int32)
    setattr_fn = (
        computation_constructing_utils.construct_named_tuple_setattr_lambda(
            tuple_type, 'a', new_value))
    expected_repr = '(let value_comp_placeholder=x in (lambda_arg -> <a=value_comp_placeholder,b=lambda_arg[1]>))'
    self.assertEqual(setattr_fn.tff_repr, expected_repr)
def test_returns_federated_map(self):
    """`create_federated_map` applies an identity lambda to a CLIENTS value."""
    identity_ref = computation_building_blocks.Reference('x', tf.int32)
    identity_fn = computation_building_blocks.Lambda(
        identity_ref.name, identity_ref.type_signature, identity_ref)
    clients_type = computation_types.FederatedType(tf.int32,
                                                   placements.CLIENTS, False)
    clients_value = computation_building_blocks.Data('y', clients_type)
    mapped = computation_constructing_utils.create_federated_map(
        identity_fn, clients_value)
    self.assertEqual(mapped.tff_repr, 'federated_map(<(x -> x),y>)')
    self.assertEqual(str(mapped.type_signature), '{int32}@CLIENTS')
def test_simple_block_inlining(self):
    """A local referenced exactly once is inlined into the block result."""
    data = computation_building_blocks.Data('test_data', tf.int32)
    block = computation_building_blocks.Block(
        [('test_x', data)],
        computation_building_blocks.Reference('test_x', data.type_signature))
    self.assertEqual(str(block), '(let test_x=test_data in test_x)')
    self.assertEqual(
        str(transformations.inline_blocks_with_n_referenced_locals(block)),
        '(let in test_data)')
def test_replace_called_lambda_replaces_called_lambda(self):
    """A lambda applied directly to an argument is rewritten as a `let`."""
    identity = _create_lambda_to_identity(tf.int32)
    operand = computation_building_blocks.Data('x', tf.int32)
    called = computation_building_blocks.Call(identity, operand)
    replaced = transformations.replace_called_lambda_with_block(called)
    # The input tree is checked after the transform to confirm it is untouched.
    self.assertEqual(called.tff_repr, '(arg -> arg)(x)')
    self.assertEqual(replaced.tff_repr, '(let arg=x in arg)')
def test_construct_setattr_named_tuple_type_leaves_type_signature_unchanged(
    self):
    """The setattr lambda maps its parameter type to an equivalent type."""
    tuple_type = computation_types.NamedTupleType([('a', tf.int32),
                                                   (None, tf.float32),
                                                   ('b', tf.bool)])
    new_value = computation_building_blocks.Data('x', tf.int32)
    setattr_fn = (
        computation_constructing_utils.construct_named_tuple_setattr_lambda(
            tuple_type, 'a', new_value))
    fn_type = setattr_fn.type_signature
    self.assertTrue(
        type_utils.are_equivalent_types(fn_type.parameter, fn_type.result))
def test_returns_string_for_data(self):
    """Compact, formatted and structural renderings of `Data` are its URI."""
    comp = computation_building_blocks.Data('data', tf.int32)
    renderers = (
        computation_building_blocks.compact_representation,
        computation_building_blocks.formatted_representation,
        computation_building_blocks.structural_representation,
    )
    for render in renderers:
        self.assertEqual(render(comp), 'data')
def test_raises_type_error_with_none_zero(self):
    """`create_federated_aggregate` raises `TypeError` when `zero` is None."""
    value_type = computation_types.FederatedType(tf.int32, placements.CLIENTS,
                                                 False)
    value = computation_building_blocks.Data('v', value_type)
    accumulate = computation_building_blocks.Lambda(
        'x', computation_types.NamedTupleType((tf.int32, tf.int32)),
        computation_building_blocks.Data('a', tf.int32))
    merge = computation_building_blocks.Lambda(
        'x', computation_types.NamedTupleType((tf.int32, tf.int32)),
        computation_building_blocks.Data('m', tf.int32))
    report_ref = computation_building_blocks.Reference('r', tf.int32)
    report = computation_building_blocks.Lambda(
        report_ref.name, report_ref.type_signature, report_ref)
    with self.assertRaises(TypeError):
        computation_constructing_utils.create_federated_aggregate(
            value, None, accumulate, merge, report)
def test_federated_setattr_call_fails_on_none_value(self):
    """`construct_federated_setattr_call` raises `TypeError` for a None value."""
    member_type = computation_types.NamedTupleType([('a', tf.int32),
                                                    (None, tf.float32),
                                                    ('b', tf.bool)])
    clients_type = computation_types.FederatedType(member_type,
                                                   placement_literals.CLIENTS)
    federated_value = computation_building_blocks.Data('data', clients_type)
    with self.assertRaises(TypeError):
        computation_constructing_utils.construct_federated_setattr_call(
            federated_value, 'a', None)
def test_remove_mapped_or_applied_identity_removes_identity(
    self, uri, type_spec, comp_factory):
    """Mapping/applying the identity lambda collapses to the bare operand.

    Parameterized: `comp_factory(fn, arg)` builds the intrinsic call whose
    URI is `uri`, over an operand of type `type_spec`.
    """
    identity = _create_lambda_to_identity(tf.int32)
    operand = computation_building_blocks.Data('x', type_spec)
    called = comp_factory(identity, operand)
    simplified = transformations.remove_mapped_or_applied_identity(called)
    self.assertEqual(called.tff_repr, '{}(<(arg -> arg),x>)'.format(uri))
    self.assertEqual(simplified.tff_repr, 'x')
def test_returns_string_for_tuple_with_no_names(self):
    """Checks all three renderings of an unnamed two-element tuple."""
    element = computation_building_blocks.Data('data', tf.int32)
    comp = computation_building_blocks.Tuple((element, element))
    self.assertEqual(comp.compact_representation(), '<data,data>')
    # NOTE(review): the indentation inside the expected layout strings below
    # may have been whitespace-collapsed in this copy of the file; confirm it
    # against the actual formatted/structural representation output.
    # pyformat: disable
    self.assertEqual(
        comp.formatted_representation(),
        '<\n'
        ' data,\n'
        ' data\n'
        '>'
    )
    self.assertEqual(
        comp.structural_representation(),
        'Tuple\n'
        '|\n'
        '[data, data]'
    )
    # pyformat: enable
def test_remove_mapped_or_applied_identity_does_not_remove_called_lambda(
    self):
    """A plain call of the identity lambda (no intrinsic) is left intact."""
    identity = _create_lambda_to_identity(tf.int32)
    operand = computation_building_blocks.Data('x', tf.int32)
    called = computation_building_blocks.Call(identity, operand)
    result = transformations.remove_mapped_or_applied_identity(called)
    self.assertEqual(called.tff_repr, '(arg -> arg)(x)')
    self.assertEqual(result.tff_repr, '(arg -> arg)(x)')
def test_does_not_find_aggregate_dependent_on_broadcast(self):
    """An aggregate fed by a broadcast passes the broadcast-dependency check.

    The broadcast is the aggregate's input, not the other way round, so
    `check_broadcast_not_dependent_on_aggregate` is expected not to raise.
    """
    broadcast = computation_test_utils.create_dummy_called_federated_broadcast()
    member_type = broadcast.type_signature.member
    zero = computation_building_blocks.Data('zero', member_type)
    accumulate = computation_building_blocks.Lambda(
        'accumulate_parameter', [member_type, member_type],
        computation_building_blocks.Data('accumulate_result', member_type))
    merge = computation_building_blocks.Lambda(
        'merge_parameter', [member_type, member_type],
        computation_building_blocks.Data('merge_result', member_type))
    report = computation_building_blocks.Lambda(
        'report_parameter', member_type,
        computation_building_blocks.Data('report_result', member_type))
    aggregate = computation_constructing_utils.create_federated_aggregate(
        broadcast, zero, accumulate, merge, report)
    tree_analysis.check_broadcast_not_dependent_on_aggregate(aggregate)
def test_no_inlining_if_referenced_twice(self):
    """A local referenced twice is kept; the block is returned unchanged."""
    data = computation_building_blocks.Data('test_data', tf.int32)
    refs = [
        computation_building_blocks.Reference('test_x', data.type_signature)
        for _ in range(2)
    ]
    block = computation_building_blocks.Block(
        [('test_x', data)], computation_building_blocks.Tuple(refs))
    self.assertEqual(str(block), '(let test_x=test_data in <test_x,test_x>)')
    inlined = transformations.inline_blocks_with_n_referenced_locals(block)
    self.assertEqual(str(inlined), str(block))
def test_replace_called_lambda_does_not_replace_separated_called_lambda(
    self):
    """A lambda wrapped in a block before being called is not rewritten."""
    identity = _create_lambda_to_identity(tf.int32)
    wrapped = _create_dummy_block(identity)
    operand = computation_building_blocks.Data('x', tf.int32)
    called = computation_building_blocks.Call(wrapped, operand)
    result = transformations.replace_called_lambda_with_block(called)
    self.assertEqual(result.tff_repr, called.tff_repr)
    self.assertEqual(result.tff_repr, '(let local=data in (arg -> arg))(x)')
def create_dummy_block(comp, variable_name, variable_type=tf.int32):
    r"""Returns a block binding one dummy local and evaluating to `comp`.

    The result is NOT an identity block: its result is `comp` itself, not a
    reference to the bound local.

                   Block
                  /     \
    [variable_name=data]  comp

    Args:
      comp: The computation to use as the block's result.
      variable_name: The name bound to the dummy `Data` computation.
      variable_type: The type of the dummy `Data` computation.

    Returns:
      A `computation_building_blocks.Block` with the single local
      `(variable_name, Data('data', variable_type))` and result `comp`.
    """
    data = computation_building_blocks.Data('data', variable_type)
    return computation_building_blocks.Block([(variable_name, data)], comp)
def create_identity_block_with_dummy_data(variable_name,
                                          variable_type=tf.int32):
    r"""Builds an identity block whose single local is a dummy `Data` node.

         Block
        /     \
    [x=data]  Ref(x)

    Args:
      variable_name: The name bound to the dummy data.
      variable_type: The type given to the dummy `Data` computation.

    Returns:
      The result of `create_identity_block` for `variable_name` and a
      `Data('data', variable_type)` computation.
    """
    return create_identity_block(
        variable_name,
        computation_building_blocks.Data('data', variable_type))
def create_dummy_called_federated_broadcast(value_type=tf.int32):
    r"""Returns a dummy called federated broadcast.

                  Call
                 /    \
    federated_broadcast  data

    The broadcast operand is a `Data` computation named 'data' whose type is
    `value_type` placed at SERVER.

    Args:
      value_type: The member type of the SERVER-placed value to broadcast.

    Returns:
      The computation built by
      `computation_constructing_utils.create_federated_broadcast`.
    """
    federated_type = computation_types.FederatedType(value_type,
                                                     placements.SERVER)
    value = computation_building_blocks.Data('data', federated_type)
    return computation_constructing_utils.create_federated_broadcast(value)
def test_no_reduce_separated_lambda_and_call(self):
    """A lambda hidden behind an empty block is not reduced when called."""

    # `foo` must keep its name: it determines `foo_arg` in the expected repr.
    @computations.federated_computation(tf.int32)
    def foo(x):
        return x

    lam = _to_building_block(foo)
    wrapped = computation_building_blocks.Block([], lam)
    operand = computation_building_blocks.Data('test', tf.int32)
    called = computation_building_blocks.Call(wrapped, operand)
    reduced = transformations.replace_called_lambdas_with_block(called)
    self.assertEqual(str(called), '(let in (foo_arg -> foo_arg))(test)')
    self.assertEqual(str(called), str(reduced))
def test_basic_functionality_of_data_class(self):
    """Checks a `Data` node's accessors, reprs, proto form and round-trip."""
    data = computation_building_blocks.Data(
        '/tmp/mydata', computation_types.SequenceType(tf.int32))
    self.assertEqual(str(data.type_signature), 'int32*')
    self.assertEqual(data.uri, '/tmp/mydata')
    self.assertEqual(
        repr(data), 'Data(\'/tmp/mydata\', SequenceType(TensorType(tf.int32)))')
    self.assertEqual(data.tff_repr, '/tmp/mydata')

    proto = data.proto
    self.assertEqual(type_serialization.deserialize_type(proto.type),
                     data.type_signature)
    self.assertEqual(proto.WhichOneof('computation'), 'data')
    self.assertEqual(proto.data.uri, data.uri)

    self._serialize_deserialize_roundtrip_test(data)
def test_replace_chained_federated_maps_does_not_replace_one_federated_map(
    self):
    """A lone `federated_map` has nothing to merge and stays unchanged."""
    identity = _create_lambda_to_identity(tf.int32)
    clients_type = computation_types.FederatedType(tf.int32,
                                                   placements.CLIENTS)
    operand = computation_building_blocks.Data('x', clients_type)
    single_map = _create_called_federated_map(identity, operand)
    result = transformations.replace_chained_federated_maps_with_federated_map(
        single_map)
    self.assertEqual(result.tff_repr, single_map.tff_repr)
    self.assertEqual(result.tff_repr, 'federated_map(<(arg -> arg),x>)')
def test_returns_federated_aggregate(self):
    """`create_federated_aggregate` builds a SERVER-placed aggregate call."""
    value_type = computation_types.FederatedType(tf.int32, placements.CLIENTS,
                                                 False)
    value = computation_building_blocks.Data('v', value_type)
    zero = computation_building_blocks.Data('z', tf.int32)
    accumulate = computation_building_blocks.Lambda(
        'x', computation_types.NamedTupleType((tf.int32, tf.int32)),
        computation_building_blocks.Data('a', tf.int32))
    merge = computation_building_blocks.Lambda(
        'x', computation_types.NamedTupleType((tf.int32, tf.int32)),
        computation_building_blocks.Data('m', tf.int32))
    report_ref = computation_building_blocks.Reference('r', tf.int32)
    report = computation_building_blocks.Lambda(
        report_ref.name, report_ref.type_signature, report_ref)
    aggregate = computation_constructing_utils.create_federated_aggregate(
        value, zero, accumulate, merge, report)
    self.assertEqual(aggregate.tff_repr,
                     'federated_aggregate(<v,z,(x -> a),(x -> m),(r -> r)>)')
    self.assertEqual(str(aggregate.type_signature), 'int32@SERVER')
def create_dummy_called_sequence_map(parameter_name, parameter_type=tf.int32):
    r"""Builds a `sequence_map` call of an identity function over dummy data.

    The mapped function comes from `create_identity_function` and the
    operand is a `Data` computation named 'data' of type `parameter_type*`.

    Args:
      parameter_name: The parameter name of the identity function.
      parameter_type: The identity function's parameter type, also used as
        the element type of the dummy sequence.

    Returns:
      The computation built by
      `computation_constructing_utils.create_sequence_map`.
    """
    identity = create_identity_function(parameter_name, parameter_type)
    sequence = computation_building_blocks.Data(
        'data', computation_types.SequenceType(parameter_type))
    return computation_constructing_utils.create_sequence_map(identity,
                                                              sequence)
def test_remove_mapped_or_applied_identity_removes_identity(
    self, uri, data_type):
    """An intrinsic `uri` applied to the identity lambda reduces to its data.

    Parameterized over the intrinsic `uri` and the operand's `data_type`.
    """
    operand = computation_building_blocks.Data('x', data_type)
    arg_ref = computation_building_blocks.Reference('arg', tf.float32)
    identity = computation_building_blocks.Lambda('arg', tf.float32, arg_ref)
    args = computation_building_blocks.Tuple([identity, operand])
    # The intrinsic takes <identity, data> and returns the data's type.
    fn_type = computation_types.FunctionType(
        [args.type_signature[0], args.type_signature[1]],
        args.type_signature[1])
    call = computation_building_blocks.Call(
        computation_building_blocks.Intrinsic(uri, fn_type), args)
    self.assertEqual(str(call), '{}(<(arg -> arg),x>)'.format(uri))
    self.assertEqual(
        str(transformations.remove_mapped_or_applied_identity(call)), 'x')
def test_remove_mapped_or_applied_identity_removes_multiple_identities(
    self):
    """Two nested identity `federated_map`s both collapse to the operand."""
    identity = _create_lambda_to_identity(tf.int32)
    clients_type = computation_types.FederatedType(tf.int32,
                                                   placements.CLIENTS)
    operand = computation_building_blocks.Data('x', clients_type)
    chained = _create_chained_called_federated_map(identity, operand, 2)
    result = transformations.remove_mapped_or_applied_identity(chained)
    self.assertEqual(
        chained.tff_repr,
        'federated_map(<(arg -> arg),federated_map(<(arg -> arg),x>)>)')
    self.assertEqual(result.tff_repr, 'x')
def test_remove_mapped_or_applied_identity_does_not_remove_other_intrinsic(
    self):
    """The identity passed to an unrelated intrinsic ('dummy') is kept."""
    identity = _create_lambda_to_identity(tf.int32)
    operand = computation_building_blocks.Data('x', tf.int32)
    dummy_type = computation_types.FunctionType(
        [identity.type_signature, operand.type_signature],
        operand.type_signature)
    dummy = computation_building_blocks.Intrinsic('dummy', dummy_type)
    call = computation_building_blocks.Call(
        dummy, computation_building_blocks.Tuple((identity, operand)))
    result = transformations.remove_mapped_or_applied_identity(call)
    expected = 'dummy(<(arg -> arg),x>)'
    self.assertEqual(call.tff_repr, expected)
    self.assertEqual(result.tff_repr, expected)
def test_is_anon_tuple_with_py_container(self):
    """Only an `AnonymousTuple` paired with a container type is accepted."""
    check = type_utils.is_anon_tuple_with_py_container

    def make_anon_tuple():
        return anonymous_tuple.AnonymousTuple([('a', 0.0)])

    def make_dict_container_type():
        return computation_types.NamedTupleTypeWithPyContainerType(
            [('a', tf.float32)], dict)

    self.assertTrue(check(make_anon_tuple(), make_dict_container_type()))

    # A `ValueImpl` is not an `AnonymousTuple`, even with a container type.
    not_anon = value_impl.ValueImpl(
        computation_building_blocks.Data('nothing', tf.int32),
        context_stack_impl.context_stack)
    self.assertFalse(check(not_anon, make_dict_container_type()))

    # A plain `NamedTupleType` carries no Python container.
    self.assertFalse(
        check(make_anon_tuple(),
              computation_types.NamedTupleType([('a', tf.float32)])))
def test_replace_intrinsic_replaces_multiple_intrinsics(self):
    """Every 'dummy' intrinsic in a chained call is replaced by the callable."""
    fn = _create_lambda_to_dummy_intrinsic(tf.int32)
    operand = computation_building_blocks.Data('x', tf.int32)
    chained = _create_chained_call(fn, operand, 2)
    transformed = transformations.replace_intrinsic_with_callable(
        chained, 'dummy', lambda x: x, context_stack_impl.context_stack)
    self.assertEqual(chained.tff_repr,
                     '(arg -> dummy(arg))((arg -> dummy(arg))(x))')
    self.assertEqual(
        transformed.tff_repr,
        '(arg -> (dummy_arg -> dummy_arg)(arg))((arg -> (dummy_arg -> dummy_arg)(arg))(x))'
    )
def test_returns_string_for_call_with_arg(self):
    """Checks all three renderings of a lambda called on a `Data` argument."""
    ref = computation_building_blocks.Reference('a', tf.int32)
    fn = computation_building_blocks.Lambda(ref.name, ref.type_signature, ref)
    arg = computation_building_blocks.Data('data', tf.int32)
    comp = computation_building_blocks.Call(fn, arg)
    self.assertEqual(comp.compact_representation(), '(a -> a)(data)')
    self.assertEqual(comp.formatted_representation(), '(a -> a)(data)')
    # NOTE(review): the column alignment inside the expected structural string
    # may have been whitespace-collapsed in this copy of the file; confirm it
    # against the actual structural representation output.
    # pyformat: disable
    self.assertEqual(
        comp.structural_representation(),
        ' Call\n'
        ' / \\\n'
        'Lambda(a) data\n'
        '|\n'
        'Ref(a)'
    )
    # pyformat: enable
def create_dummy_called_federated_map(parameter_name, parameter_type=tf.int32):
    r"""Builds a `federated_map` call of an identity function over dummy data.

              Call
             /    \
    federated_map  Tuple
                   |
                   [Lambda(x), data]
                    |
                    Ref(x)

    Args:
      parameter_name: The parameter name of the identity function.
      parameter_type: The identity function's parameter type, also used as
        the member type of the CLIENTS-placed dummy data.

    Returns:
      The computation built by
      `computation_constructing_utils.create_federated_map`.
    """
    identity = create_identity_function(parameter_name, parameter_type)
    clients_value = computation_building_blocks.Data(
        'data',
        computation_types.FederatedType(parameter_type, placements.CLIENTS))
    return computation_constructing_utils.create_federated_map(
        identity, clients_value)
def test_replace_chained_federated_maps_does_not_replace_separated_federated_maps(
    self):
    """Two `federated_map`s separated by a block are not merged."""
    inner_fn = _create_lambda_to_identity(tf.int32)
    clients_type = computation_types.FederatedType(tf.int32,
                                                   placements.CLIENTS)
    operand = computation_building_blocks.Data('x', clients_type)
    inner_map = _create_called_federated_map(inner_fn, operand)
    # Wrapping the inner map in a block breaks the direct map-of-map chain.
    separator = _create_dummy_block(inner_map)
    outer_fn = _create_lambda_to_identity(tf.int32)
    outer_map = _create_called_federated_map(outer_fn, separator)
    result = transformations.replace_chained_federated_maps_with_federated_map(
        outer_map)
    self.assertEqual(result.tff_repr, outer_map.tff_repr)
    self.assertEqual(
        result.tff_repr,
        'federated_map(<(arg -> arg),(let local=data in federated_map(<(arg -> arg),x>))>)'
    )