def test_replace_chained_federated_maps_does_not_replace_unchained_federated_maps(
    self):
  """Two maps linked through a Tuple/Selection round-trip stay separate.

  The outer federated_map's argument is a Selection from a one-element Tuple
  holding the inner federated_map call, not the call itself. The test expects
  the transformation to keep both federated_map intrinsics and to preserve
  the computed value.
  """
  federated_map_uri = intrinsic_defs.FEDERATED_MAP.uri
  client_int_type = computation_types.FederatedType(tf.int32,
                                                    placements.CLIENTS)
  param = computation_building_blocks.Reference('arg', client_int_type)

  first_map = _create_call_to_federated_map(
      _create_lambda_to_add_one(param.type_signature.member), param)
  # Put the first map into a tuple and select it back out, so that the second
  # map does not consume the first map's call directly.
  wrapped = computation_building_blocks.Tuple([first_map])
  selected = computation_building_blocks.Selection(wrapped, index=0)
  second_map = _create_call_to_federated_map(
      _create_lambda_to_add_one(
          first_map.function.type_signature.result.member), selected)
  comp = computation_building_blocks.Lambda(param.name, param.type_signature,
                                            second_map)

  self.assertEqual(_get_number_of_intrinsics(comp, federated_map_uri), 2)
  self.assertEqual(_to_comp(comp)([1]), [3])

  transformed_comp = (
      transformations.replace_chained_federated_maps_with_federated_map(comp))

  self.assertEqual(
      _get_number_of_intrinsics(transformed_comp, federated_map_uri), 2)
  self.assertEqual(_to_comp(transformed_comp)([1]), [3])
def test_replace_chained_federated_maps_with_different_arg_types(self):
  """Chained maps whose element types differ are merged into one map.

  The inner map is built with `_create_lambda_to_cast(tf.int32, tf.float32)`
  and the outer map with `_create_lambda_to_add_one` on the inner result's
  member type. After the transformation the test expects one federated_map
  and an unchanged result.
  """
  federated_map_uri = intrinsic_defs.FEDERATED_MAP.uri
  param = computation_building_blocks.Reference(
      'arg_1', computation_types.FederatedType(tf.int32, placements.CLIENTS))

  cast_map = _create_call_to_federated_map(
      _create_lambda_to_cast(tf.int32, tf.float32), param)
  add_one_map = _create_call_to_federated_map(
      _create_lambda_to_add_one(cast_map.type_signature.member), cast_map)
  comp = computation_building_blocks.Lambda(param.name, param.type_signature,
                                            add_one_map)

  self.assertEqual(_get_number_of_intrinsics(comp, federated_map_uri), 2)
  self.assertEqual(_to_comp(comp)([1]), [2.0])

  transformed_comp = (
      transformations.replace_chained_federated_maps_with_federated_map(comp))

  self.assertEqual(
      _get_number_of_intrinsics(transformed_comp, federated_map_uri), 1)
  self.assertEqual(_to_comp(transformed_comp)([1]), [2.0])
def test_replace_chained_federated_maps_replaces_multiple_federated_maps(
    self):
  """A chain of ten add-one federated_maps collapses into a single map."""
  chain_length = 10
  federated_map_uri = intrinsic_defs.FEDERATED_MAP.uri
  param = computation_building_blocks.Reference(
      'arg', computation_types.FederatedType(tf.int32, placements.CLIENTS))

  # Each iteration feeds the previous call into a new add-one federated_map.
  # The element type of the latest result becomes the next lambda's parameter.
  chained = param
  element_type = param.type_signature.member
  for _ in range(chain_length):
    chained = _create_call_to_federated_map(
        _create_lambda_to_add_one(element_type), chained)
    element_type = chained.function.type_signature.result.member
  comp = computation_building_blocks.Lambda(param.name, param.type_signature,
                                            chained)

  self.assertEqual(
      _get_number_of_intrinsics(comp, federated_map_uri), chain_length)
  # The input 1 is incremented once per map in the chain.
  self.assertEqual(_to_comp(comp)([1]), [1 + chain_length])

  transformed_comp = (
      transformations.replace_chained_federated_maps_with_federated_map(comp))

  self.assertEqual(
      _get_number_of_intrinsics(transformed_comp, federated_map_uri), 1)
  self.assertEqual(_to_comp(transformed_comp)([1]), [1 + chain_length])
def test_replace_chained_federated_maps_does_not_replace_one_federated_map(
    self):
  """A single federated_map has no chain partner and is left unchanged."""
  identity_fn = _create_lambda_to_identity(tf.int32)
  client_data = computation_building_blocks.Data(
      'x', computation_types.FederatedType(tf.int32, placements.CLIENTS))
  comp = _create_called_federated_map(identity_fn, client_data)

  transformed_comp = (
      transformations.replace_chained_federated_maps_with_federated_map(comp))

  self.assertEqual(transformed_comp.tff_repr, comp.tff_repr)
  self.assertEqual(transformed_comp.tff_repr,
                   'federated_map(<(arg -> arg),x>)')
def test_replace_chained_federated_maps_does_not_replace_separated_federated_maps(
    self):
  """Two maps with a Block between them are not merged.

  The inner map is wrapped by `_create_dummy_block` before being passed to the
  outer map. The expected repr shows a `let local=data in ...` Block, and the
  test expects the transformation to leave it as is.
  """
  first_fn = _create_lambda_to_identity(tf.int32)
  client_data = computation_building_blocks.Data(
      'x', computation_types.FederatedType(tf.int32, placements.CLIENTS))
  first_map = _create_called_federated_map(first_fn, client_data)
  separator = _create_dummy_block(first_map)
  second_fn = _create_lambda_to_identity(tf.int32)
  comp = _create_called_federated_map(second_fn, separator)

  transformed_comp = (
      transformations.replace_chained_federated_maps_with_federated_map(comp))

  self.assertEqual(transformed_comp.tff_repr, comp.tff_repr)
  self.assertEqual(
      transformed_comp.tff_repr,
      'federated_map(<(arg -> arg),(let local=data in federated_map(<(arg -> arg),x>))>)'
  )
def test_replace_chained_federated_maps_replaces_federated_maps_with_different_types(
    self):
  """An int32->float32 map followed by a float32 map is fused into one map.

  The expected fused repr composes the two lambdas inside a single
  federated_map over `x`.
  """
  cast_fn = _create_lambda_to_dummy_cast(tf.int32, tf.float32)
  client_ref = computation_building_blocks.Reference(
      'x', computation_types.FederatedType(tf.int32, placements.CLIENTS))
  cast_map = _create_called_federated_map(cast_fn, client_ref)
  float_identity_fn = _create_lambda_to_identity(tf.float32)
  comp = _create_called_federated_map(float_identity_fn, cast_map)

  transformed_comp = (
      transformations.replace_chained_federated_maps_with_federated_map(comp))

  # The original is checked after the transformation runs, which also shows
  # the input computation was not altered.
  self.assertEqual(
      comp.tff_repr,
      'federated_map(<(arg -> arg),federated_map(<(arg -> data),x>)>)')
  self.assertEqual(
      transformed_comp.tff_repr,
      'federated_map(<(arg -> (arg -> arg)((arg -> data)(arg))),x>)')
def test_replace_chained_federated_maps_replaces_nested_federated_maps(
    self):
  """A chain of maps nested inside a Block is still fused.

  Two chained identity maps are built and placed inside a dummy Block. The
  expected repr shows them fused into one federated_map inside the Block.
  """
  identity_fn = _create_lambda_to_identity(tf.int32)
  client_data = computation_building_blocks.Data(
      'x', computation_types.FederatedType(tf.int32, placements.CLIENTS))
  chained_maps = _create_chained_called_federated_map(identity_fn,
                                                      client_data, 2)
  comp = _create_dummy_block(chained_maps)

  transformed_comp = (
      transformations.replace_chained_federated_maps_with_federated_map(comp))

  self.assertEqual(
      comp.tff_repr,
      '(let local=data in federated_map(<(arg -> arg),federated_map(<(arg -> arg),x>)>))'
  )
  self.assertEqual(
      transformed_comp.tff_repr,
      '(let local=data in federated_map(<(arg -> (arg -> arg)((arg -> arg)(arg))),x>))'
  )
def test_replace_chained_federated_maps_raises_type_error(self):
  """Passing `None` instead of a building block raises TypeError."""
  self.assertRaises(
      TypeError,
      transformations.replace_chained_federated_maps_with_federated_map, None)