def test_identical_to_merge_tuple_intrinsics_with_different_intrinsics(
    self):
  """Dedupe-and-merge agrees with a direct merge for unequal aggregates."""
  int_aggregate = test_utils.create_dummy_called_federated_aggregate(
      accumulate_parameter_name='a',
      merge_parameter_name='b',
      report_parameter_name='c',
      value_type=tf.int32)
  # These compare as not equal (different parameter names and value types).
  float_aggregate = test_utils.create_dummy_called_federated_aggregate(
      accumulate_parameter_name='x',
      merge_parameter_name='y',
      report_parameter_name='z',
      value_type=tf.float32)
  comp = building_blocks.Tuple((int_aggregate, float_aggregate))
  aggregate_uri = intrinsic_defs.FEDERATED_AGGREGATE.uri

  deduped, deduped_modified = (
      compiler_transformations.dedupe_and_merge_tuple_intrinsics(
          comp, aggregate_uri))
  merged, merged_modified = transformations.merge_tuple_intrinsics(
      comp, aggregate_uri)

  self.assertTrue(deduped_modified)
  self.assertTrue(merged_modified)
  self.assertEqual(deduped.formatted_representation(),
                   merged.formatted_representation())
def test_aggregate_with_selection_from_block_by_name_results_in_single_aggregate(
    self):
  """Aggregates over named selections of one block dedupe to one aggregate."""
  data = building_blocks.Reference(
      'a', computation_types.FederatedType(tf.int32, placement_literals.CLIENTS))
  named_pair = building_blocks.Tuple([('a', data), ('b', data)])
  block = building_blocks.Block([], named_pair)
  select_a = building_blocks.Selection(source=block, name='a')
  select_b = building_blocks.Selection(source=block, name='b')

  result = building_blocks.Data('aggregation_result', tf.int32)
  zero = building_blocks.Data('zero', tf.int32)
  accumulate = building_blocks.Lambda('accumulate_param',
                                      [tf.int32, tf.int32], result)
  merge = building_blocks.Lambda('merge_param', [tf.int32, tf.int32], result)
  report = building_blocks.Lambda('report_param', tf.int32, result)

  # One federated_aggregate per selection, in order 'a' then 'b'.
  aggregates = tuple(
      building_block_factory.create_federated_aggregate(
          selection, zero, accumulate, merge, report)
      for selection in (select_a, select_b))
  comp = building_blocks.Tuple(aggregates)

  transformed, modified = transformations.dedupe_and_merge_tuple_intrinsics(
      comp, intrinsic_defs.FEDERATED_AGGREGATE.uri)
  self.assertTrue(modified)

  found_intrinsics = []

  def _collect_aggregate_intrinsic(node):
    # Record the intrinsic of every call to federated_aggregate; never modify.
    is_aggregate_call = (
        isinstance(node, building_blocks.Call) and
        isinstance(node.function, building_blocks.Intrinsic) and
        node.function.uri == intrinsic_defs.FEDERATED_AGGREGATE.uri)
    if is_aggregate_call:
      found_intrinsics.append(node.function)
    return node, False

  transformation_utils.transform_postorder(transformed,
                                           _collect_aggregate_intrinsic)

  self.assertLen(found_intrinsics, 1)
  self.assertEqual(
      found_intrinsics[0].type_signature.parameter[0].compact_representation(),
      '{<int32>}@CLIENTS')
def test_constructs_broadcast_of_tuple_with_one_element(self):
  """Two identical broadcasts collapse into one broadcast of a 1-tuple."""
  broadcast = test_utils.create_dummy_called_federated_broadcast()
  comp = building_blocks.Tuple((broadcast, broadcast))
  broadcast_uri = intrinsic_defs.FEDERATED_BROADCAST.uri

  transformed_comp, modified = (
      compiler_transformations.dedupe_and_merge_tuple_intrinsics(
          comp, broadcast_uri))

  found_broadcasts = []

  def _collect_broadcast(node):
    # Record each called federated_broadcast; never modify the tree.
    if building_block_analysis.is_called_intrinsic(node, broadcast_uri):
      found_broadcasts.append(node)
    return node, False

  transformation_utils.transform_postorder(transformed_comp,
                                           _collect_broadcast)

  expected_representation = (
      '(_var1 -> <\n'
      ' _var1[0],\n'
      ' _var1[0]\n'
      '>)((x -> <\n'
      ' x[0]\n'
      '>)((let\n'
      ' value=federated_broadcast(federated_apply(<\n'
      ' (arg -> <\n'
      ' arg\n'
      ' >),\n'
      ' <\n'
      ' data\n'
      ' >[0]\n'
      ' >))\n'
      ' in <\n'
      ' federated_map_all_equal(<\n'
      ' (arg -> arg[0]),\n'
      ' value\n'
      ' >)\n'
      '>)))')

  self.assertTrue(modified)
  self.assertEqual(
      comp.compact_representation(),
      '<federated_broadcast(data),federated_broadcast(data)>')
  self.assertLen(found_broadcasts, 1)
  self.assertLen(found_broadcasts[0].type_signature.member, 1)
  self.assertEqual(transformed_comp.formatted_representation(),
                   expected_representation)
def test_noops_in_case_of_distinct_applies(self):
  """Dedupe-and-merge agrees with a direct merge for two distinct applies."""
  int_apply = test_utils.create_dummy_called_federated_apply(
      parameter_name='a', parameter_type=tf.int32)
  float_apply = test_utils.create_dummy_called_federated_apply(
      parameter_name='a', parameter_type=tf.float32)
  comp = building_blocks.Tuple((int_apply, float_apply))
  apply_uri = intrinsic_defs.FEDERATED_APPLY.uri

  deduped, deduped_modified = (
      compiler_transformations.dedupe_and_merge_tuple_intrinsics(
          comp, apply_uri))
  merged, merged_modified = transformations.merge_tuple_intrinsics(
      comp, apply_uri)

  self.assertTrue(deduped_modified)
  self.assertTrue(merged_modified)
  self.assertEqual(deduped.compact_representation(),
                   merged.compact_representation())
def test_dedupe_noops_in_case_of_distinct_broadcasts(self):
  """Dedupe-and-merge agrees with a direct merge for distinct broadcasts."""
  int_broadcast = test_utils.create_dummy_called_federated_broadcast(
      tf.int32)
  float_broadcast = test_utils.create_dummy_called_federated_broadcast(
      tf.float32)
  comp = building_blocks.Tuple((int_broadcast, float_broadcast))
  broadcast_uri = intrinsic_defs.FEDERATED_BROADCAST.uri

  deduped, deduped_modified = transformations.dedupe_and_merge_tuple_intrinsics(
      comp, broadcast_uri)
  merged, merged_modified = tree_transformations.merge_tuple_intrinsics(
      comp, broadcast_uri)

  self.assertTrue(deduped_modified)
  self.assertTrue(merged_modified)
  self.assertEqual(deduped.compact_representation(),
                   merged.compact_representation())
def _force_align_intrinsics_to_top_level_lambda(comp, uri):
  """Forcefully aligns `comp` by the intrinsics for the given `uri`.

  Extracts, groups, and potentially merges every called intrinsic whose URI
  is listed in `uri`. The result of this transformation should contain exactly
  one instance of the intrinsic for each given URI, bound only by the
  `parameter_name` of `comp`.

  Args:
    comp: The `building_blocks.Lambda` to align.
    uri: A Python `list` of intrinsic URI strings.

  Returns:
    A new computation with the transformation applied or the original `comp`.
  """
  py_typecheck.check_type(comp, building_blocks.Lambda)
  py_typecheck.check_type(uri, list)
  for uri_entry in uri:
    py_typecheck.check_type(uri_entry, str)

  aligned, _ = tree_transformations.uniquify_reference_names(comp)
  if not _can_extract_intrinsics_to_top_level_lambda(aligned, uri):
    aligned, _ = tree_transformations.replace_called_lambda_with_block(aligned)
  aligned = _inline_block_variables_required_to_align_intrinsics(aligned, uri)
  aligned, extracted = _extract_intrinsics_to_top_level_lambda(aligned, uri)

  # Nothing was pulled up to the top-level lambda, so there is nothing to
  # group or merge.
  if not extracted:
    return aligned

  if len(uri) > 1:
    aligned, _ = _group_by_intrinsics_in_top_level_lambda(aligned)

  any_merged = False
  for intrinsic_uri in uri:
    aligned, merged = transformations.dedupe_and_merge_tuple_intrinsics(
        aligned, intrinsic_uri)
    if merged:
      # Required because merging called intrinsics invokes building block
      # factories that do not name references uniquely.
      aligned, _ = tree_transformations.uniquify_reference_names(aligned)
      any_merged = True

  if any_merged:
    # Required because merging called intrinsics will nest the called
    # intrinsics such that they can no longer be split.
    aligned, _ = _extract_intrinsics_to_top_level_lambda(aligned, uri)
  return aligned
def test_constructs_aggregate_of_tuple_with_one_element(self):
  """Two identical aggregates collapse into one aggregate of a 1-tuple."""
  aggregate = test_utils.create_dummy_called_federated_aggregate(
      accumulate_parameter_name='a',
      merge_parameter_name='b',
      report_parameter_name='c')
  comp = building_blocks.Tuple((aggregate, aggregate))
  aggregate_uri = intrinsic_defs.FEDERATED_AGGREGATE.uri

  transformed_comp, modified = (
      compiler_transformations.dedupe_and_merge_tuple_intrinsics(
          comp, aggregate_uri))

  found_aggregates = []

  def _collect_aggregate(node):
    # Record each called federated_aggregate; never modify the tree.
    if building_block_analysis.is_called_intrinsic(node, aggregate_uri):
      found_aggregates.append(node)
    return node, False

  transformation_utils.transform_postorder(transformed_comp,
                                           _collect_aggregate)

  expected_representation = (
      '(_var1 -> <\n'
      ' _var1[0],\n'
      ' _var1[0]\n'
      '>)((x -> <\n'
      ' x[0]\n'
      '>)((let\n'
      ' value=federated_aggregate(<\n'
      ' federated_map(<\n'
      ' (arg -> <\n'
      ' arg\n'
      ' >),\n'
      ' <\n'
      ' data\n'
      ' >[0]\n'
      ' >),\n'
      ' <\n'
      ' data\n'
      ' >,\n'
      ' (let\n'
      ' _var1=<\n'
      ' (a -> data)\n'
      ' >\n'
      ' in (_var2 -> <\n'
      ' _var1[0](<\n'
      ' <\n'
      ' _var2[0][0],\n'
      ' _var2[1][0]\n'
      ' >\n'
      ' >[0])\n'
      ' >)),\n'
      ' (let\n'
      ' _var3=<\n'
      ' (b -> data)\n'
      ' >\n'
      ' in (_var4 -> <\n'
      ' _var3[0](<\n'
      ' <\n'
      ' _var4[0][0],\n'
      ' _var4[1][0]\n'
      ' >\n'
      ' >[0])\n'
      ' >)),\n'
      ' (let\n'
      ' _var5=<\n'
      ' (c -> data)\n'
      ' >\n'
      ' in (_var6 -> <\n'
      ' _var5[0](_var6[0])\n'
      ' >))\n'
      ' >)\n'
      ' in <\n'
      ' federated_apply(<\n'
      ' (arg -> arg[0]),\n'
      ' value\n'
      ' >)\n'
      '>)))')

  self.assertTrue(modified)
  self.assertLen(found_aggregates, 1)
  self.assertLen(found_aggregates[0].type_signature.member, 1)
  self.assertEqual(transformed_comp.formatted_representation(),
                   expected_representation)