def test_identical_to_merge_tuple_intrinsics_with_different_intrinsics(
            self):
        called_intrinsic1 = test_utils.create_dummy_called_federated_aggregate(
            accumulate_parameter_name='a',
            merge_parameter_name='b',
            report_parameter_name='c',
            value_type=tf.int32)
        # These compare as not equal.
        called_intrinsic2 = test_utils.create_dummy_called_federated_aggregate(
            accumulate_parameter_name='x',
            merge_parameter_name='y',
            report_parameter_name='z',
            value_type=tf.float32)
        calls = building_blocks.Tuple((called_intrinsic1, called_intrinsic2))
        comp = calls

        deduped_and_merged_comp, deduped_modified = compiler_transformations.dedupe_and_merge_tuple_intrinsics(
            comp, intrinsic_defs.FEDERATED_AGGREGATE.uri)
        directly_merged_comp, directly_modified = transformations.merge_tuple_intrinsics(
            comp, intrinsic_defs.FEDERATED_AGGREGATE.uri)

        self.assertTrue(deduped_modified)
        self.assertTrue(directly_modified)
        self.assertEqual(deduped_and_merged_comp.formatted_representation(),
                         directly_merged_comp.formatted_representation())
    def test_aggregate_with_selection_from_block_by_name_results_in_single_aggregate(
            self):
        """Aggregating both named fields of one block yields a single aggregate.

        Both fields `a` and `b` of the block's tuple refer to the same
        `data` reference, so after dedupe-and-merge exactly one called
        `federated_aggregate` is expected to remain.
        """
        data = building_blocks.Reference(
            'a',
            computation_types.FederatedType(tf.int32,
                                            placement_literals.CLIENTS))
        block_holding_tup = building_blocks.Block(
            [], building_blocks.Tuple([('a', data), ('b', data)]))
        first_field = building_blocks.Selection(
            source=block_holding_tup, name='a')
        second_field = building_blocks.Selection(
            source=block_holding_tup, name='b')

        # Shared aggregation components used by both aggregates.
        result = building_blocks.Data('aggregation_result', tf.int32)
        zero = building_blocks.Data('zero', tf.int32)
        accumulate = building_blocks.Lambda('accumulate_param',
                                            [tf.int32, tf.int32], result)
        merge = building_blocks.Lambda('merge_param', [tf.int32, tf.int32],
                                       result)
        report = building_blocks.Lambda('report_param', tf.int32, result)

        aggregates = [
            building_block_factory.create_federated_aggregate(
                field, zero, accumulate, merge, report)
            for field in (first_field, second_field)
        ]
        comp = building_blocks.Tuple(tuple(aggregates))

        transformed, modified = transformations.dedupe_and_merge_tuple_intrinsics(
            comp, intrinsic_defs.FEDERATED_AGGREGATE.uri)

        self.assertTrue(modified)

        aggregate_intrinsics = []

        def _collect_aggregate_intrinsics(node):
            is_aggregate_call = (
                isinstance(node, building_blocks.Call) and
                isinstance(node.function, building_blocks.Intrinsic) and
                node.function.uri == intrinsic_defs.FEDERATED_AGGREGATE.uri)
            if is_aggregate_call:
                aggregate_intrinsics.append(node.function)
            return node, False

        transformation_utils.transform_postorder(
            transformed, _collect_aggregate_intrinsics)
        self.assertLen(aggregate_intrinsics, 1)
        value_type = aggregate_intrinsics[0].type_signature.parameter[0]
        self.assertEqual(value_type.compact_representation(),
                         '{<int32>}@CLIENTS')
    def test_constructs_broadcast_of_tuple_with_one_element(self):
        """Two identical broadcasts collapse into one single-element broadcast."""
        broadcast_call = test_utils.create_dummy_called_federated_broadcast()
        comp = building_blocks.Tuple((broadcast_call, broadcast_call))

        transformed_comp, modified = compiler_transformations.dedupe_and_merge_tuple_intrinsics(
            comp, intrinsic_defs.FEDERATED_BROADCAST.uri)

        broadcasts = []

        def _collect_broadcasts(node):
            if building_block_analysis.is_called_intrinsic(
                    node, intrinsic_defs.FEDERATED_BROADCAST.uri):
                broadcasts.append(node)
            return node, False

        transformation_utils.transform_postorder(transformed_comp,
                                                 _collect_broadcasts)

        self.assertTrue(modified)
        # The input `comp` still has its original, un-merged form.
        self.assertEqual(
            comp.compact_representation(),
            '<federated_broadcast(data),federated_broadcast(data)>')

        self.assertLen(broadcasts, 1)
        self.assertLen(broadcasts[0].type_signature.member, 1)
        expected_representation = ('(_var1 -> <\n'
                                   '  _var1[0],\n'
                                   '  _var1[0]\n'
                                   '>)((x -> <\n'
                                   '  x[0]\n'
                                   '>)((let\n'
                                   '  value=federated_broadcast(federated_apply(<\n'
                                   '    (arg -> <\n'
                                   '      arg\n'
                                   '    >),\n'
                                   '    <\n'
                                   '      data\n'
                                   '    >[0]\n'
                                   '  >))\n'
                                   ' in <\n'
                                   '  federated_map_all_equal(<\n'
                                   '    (arg -> arg[0]),\n'
                                   '    value\n'
                                   '  >)\n'
                                   '>)))')
        self.assertEqual(transformed_comp.formatted_representation(),
                         expected_representation)
    def test_noops_in_case_of_distinct_applies(self):
        """Distinct applies: dedupe-and-merge gives the same result as a merge."""
        int_apply = test_utils.create_dummy_called_federated_apply(
            parameter_name='a', parameter_type=tf.int32)
        float_apply = test_utils.create_dummy_called_federated_apply(
            parameter_name='a', parameter_type=tf.float32)
        comp = building_blocks.Tuple((int_apply, float_apply))
        apply_uri = intrinsic_defs.FEDERATED_APPLY.uri

        deduped_comp, deduped_modified = (
            compiler_transformations.dedupe_and_merge_tuple_intrinsics(
                comp, apply_uri))
        merged_comp, merged_modified = transformations.merge_tuple_intrinsics(
            comp, apply_uri)

        self.assertTrue(deduped_modified)
        self.assertTrue(merged_modified)
        self.assertEqual(deduped_comp.compact_representation(),
                         merged_comp.compact_representation())
    def test_dedupe_noops_in_case_of_distinct_broadcasts(self):
        """Distinct broadcasts: dedupe-and-merge gives the same result as a merge."""
        int_broadcast = test_utils.create_dummy_called_federated_broadcast(
            tf.int32)
        float_broadcast = test_utils.create_dummy_called_federated_broadcast(
            tf.float32)
        comp = building_blocks.Tuple((int_broadcast, float_broadcast))
        broadcast_uri = intrinsic_defs.FEDERATED_BROADCAST.uri

        deduped_comp, deduped_modified = (
            transformations.dedupe_and_merge_tuple_intrinsics(
                comp, broadcast_uri))
        merged_comp, merged_modified = (
            tree_transformations.merge_tuple_intrinsics(comp, broadcast_uri))

        self.assertTrue(deduped_modified)
        self.assertTrue(merged_modified)
        self.assertEqual(deduped_comp.compact_representation(),
                         merged_comp.compact_representation())
# Example #6
# 0
def _force_align_intrinsics_to_top_level_lambda(comp, uri):
    """Forcefully aligns `comp` by the intrinsics whose URIs are listed in `uri`.

    The called intrinsics matching `uri` are extracted to the top-level lambda.
    They are grouped when more than one URI is given, and each kind is then
    deduped and merged. The result of this transformation should contain
    exactly one instance of each intrinsic in `uri`, bound only by the
    `parameter_name` of `comp`.

    Args:
      comp: The `building_blocks.Lambda` to align.
      uri: A Python `list` of intrinsic URI strings.

    Returns:
      A new computation with the transformation applied or the original `comp`.

    Raises:
      Whatever `py_typecheck.check_type` raises (presumably `TypeError`; confirm)
      if `comp` is not a Lambda, `uri` is not a list, or an entry is not a str.
    """
    py_typecheck.check_type(comp, building_blocks.Lambda)
    py_typecheck.check_type(uri, list)
    for single_uri in uri:
        py_typecheck.check_type(single_uri, str)

    comp, _ = tree_transformations.uniquify_reference_names(comp)
    if not _can_extract_intrinsics_to_top_level_lambda(comp, uri):
        comp, _ = tree_transformations.replace_called_lambda_with_block(comp)
    comp = _inline_block_variables_required_to_align_intrinsics(comp, uri)
    comp, extracted = _extract_intrinsics_to_top_level_lambda(comp, uri)
    if not extracted:
        return comp

    if len(uri) > 1:
        comp, _ = _group_by_intrinsics_in_top_level_lambda(comp)

    any_merged = False
    for intrinsic_uri in uri:
        comp, merged = transformations.dedupe_and_merge_tuple_intrinsics(
            comp, intrinsic_uri)
        if merged:
            # Required because merging called intrinsics invokes building block
            # factories that do not name references uniquely.
            comp, _ = tree_transformations.uniquify_reference_names(comp)
            any_merged = True

    if any_merged:
        # Required because merging called intrinsics will nest the called
        # intrinsics such that they can no longer be split.
        comp, _ = _extract_intrinsics_to_top_level_lambda(comp, uri)
    return comp
    def test_constructs_aggregate_of_tuple_with_one_element(self):
        """Two identical aggregates collapse into one single-element aggregate."""
        aggregate_call = test_utils.create_dummy_called_federated_aggregate(
            accumulate_parameter_name='a',
            merge_parameter_name='b',
            report_parameter_name='c')
        comp = building_blocks.Tuple((aggregate_call, aggregate_call))

        transformed_comp, modified = compiler_transformations.dedupe_and_merge_tuple_intrinsics(
            comp, intrinsic_defs.FEDERATED_AGGREGATE.uri)

        aggregates = []

        def _collect_aggregates(node):
            if building_block_analysis.is_called_intrinsic(
                    node, intrinsic_defs.FEDERATED_AGGREGATE.uri):
                aggregates.append(node)
            return node, False

        transformation_utils.transform_postorder(transformed_comp,
                                                 _collect_aggregates)
        self.assertTrue(modified)
        self.assertLen(aggregates, 1)
        self.assertLen(aggregates[0].type_signature.member, 1)
        expected_representation = ('(_var1 -> <\n'
                                   '  _var1[0],\n'
                                   '  _var1[0]\n'
                                   '>)((x -> <\n'
                                   '  x[0]\n'
                                   '>)((let\n'
                                   '  value=federated_aggregate(<\n'
                                   '    federated_map(<\n'
                                   '      (arg -> <\n'
                                   '        arg\n'
                                   '      >),\n'
                                   '      <\n'
                                   '        data\n'
                                   '      >[0]\n'
                                   '    >),\n'
                                   '    <\n'
                                   '      data\n'
                                   '    >,\n'
                                   '    (let\n'
                                   '      _var1=<\n'
                                   '        (a -> data)\n'
                                   '      >\n'
                                   '     in (_var2 -> <\n'
                                   '      _var1[0](<\n'
                                   '        <\n'
                                   '          _var2[0][0],\n'
                                   '          _var2[1][0]\n'
                                   '        >\n'
                                   '      >[0])\n'
                                   '    >)),\n'
                                   '    (let\n'
                                   '      _var3=<\n'
                                   '        (b -> data)\n'
                                   '      >\n'
                                   '     in (_var4 -> <\n'
                                   '      _var3[0](<\n'
                                   '        <\n'
                                   '          _var4[0][0],\n'
                                   '          _var4[1][0]\n'
                                   '        >\n'
                                   '      >[0])\n'
                                   '    >)),\n'
                                   '    (let\n'
                                   '      _var5=<\n'
                                   '        (c -> data)\n'
                                   '      >\n'
                                   '     in (_var6 -> <\n'
                                   '      _var5[0](_var6[0])\n'
                                   '    >))\n'
                                   '  >)\n'
                                   ' in <\n'
                                   '  federated_apply(<\n'
                                   '    (arg -> arg[0]),\n'
                                   '    value\n'
                                   '  >)\n'
                                   '>)))')
        self.assertEqual(transformed_comp.formatted_representation(),
                         expected_representation)