def test_splits_on_selected_intrinsic_broadcast(self):
    """Splitting a one-parameter lambda on the federated broadcast intrinsic.

    The lambda takes an `int32` parameter named 'a' (unused) and returns a
    struct holding a single whimsy called federated broadcast. The expected
    split point passed to `assert_splits_on` is a null federated broadcast.
    """
    broadcast_call = (
        building_block_test_utils.create_whimsy_called_federated_broadcast())
    lambda_result = building_blocks.Struct([broadcast_call])
    computation = building_blocks.Lambda('a', tf.int32, lambda_result)
    self.assert_splits_on(
        computation,
        building_block_factory.create_null_federated_broadcast())
 def test_splits_on_intrinsic_noarg_function(self):
     """Splitting a parameterless lambda on the federated broadcast intrinsic.

     Same shape as the one-parameter case, but the lambda is built with no
     parameter name and no parameter type (`Lambda(None, None, ...)`).
     """
     broadcast_call = (
         building_block_test_utils.create_whimsy_called_federated_broadcast())
     lambda_result = building_blocks.Struct([broadcast_call])
     no_arg_computation = building_blocks.Lambda(None, None, lambda_result)
     self.assert_splits_on(
         no_arg_computation,
         building_block_factory.create_null_federated_broadcast())
 def test_splits_on_selected_intrinsic_nested_in_tuple_broadcast(self):
     """Splitting on a broadcast whose argument is selected from a struct.

     A server-placed `int32` value and a called broadcast are packed into a
     struct. Element 0 of that struct is selected and broadcast again. The
     result is normalized with `to_call_dominant` and wrapped in a lambda, and
     the null federated broadcast is the expected split point.
     """
     inner_broadcast = (
         building_block_test_utils.create_whimsy_called_federated_broadcast())
     server_value = building_blocks.Data(
         'a', computation_types.at_server(tf.int32))
     packed = building_blocks.Struct([server_value, inner_broadcast])
     outer_broadcast = building_block_factory.create_federated_broadcast(
         building_blocks.Selection(packed, index=0))
     call_dominant = transformations.to_call_dominant(outer_broadcast)
     computation = building_blocks.Lambda('a', tf.int32, call_dominant)
     self.assert_splits_on(
         computation,
         building_block_factory.create_null_federated_broadcast())
Example #4
0
 def test_does_not_find_aggregate_dependent_on_broadcast(self):
     """An aggregate fed by a broadcast is not flagged as aggregate-dependent.

     Builds a federated aggregate whose value is a whimsy called broadcast.
     Its zero, accumulate, merge and report pieces are opaque `Data` nodes
     typed by the broadcast's member type. The test has no explicit
     assertion: it relies on `check_aggregate_not_dependent_on_aggregate`
     not raising.
     """
     broadcast = (
         building_block_test_utils.create_whimsy_called_federated_broadcast())
     member_type = broadcast.type_signature.member
     zero = building_blocks.Data('zero', member_type)
     # Accumulate and merge take a pair of member values. A fresh list is
     # built for each Lambda, as in the original.
     accumulate = building_blocks.Lambda(
         'accumulate_parameter', [member_type] * 2,
         building_blocks.Data('accumulate_result', member_type))
     merge = building_blocks.Lambda(
         'merge_parameter', [member_type] * 2,
         building_blocks.Data('merge_result', member_type))
     report = building_blocks.Lambda(
         'report_parameter', member_type,
         building_blocks.Data('report_result', member_type))
     aggregate_of_broadcast = building_block_factory.create_federated_aggregate(
         broadcast, zero, accumulate, merge, report)
     tree_analysis.check_aggregate_not_dependent_on_aggregate(
         aggregate_of_broadcast)
Example #5
0
 def test_returns_false_with_unmatched_called_intrinsic(self):
     """A called broadcast is not reported as containing `federated_map`."""
     called_broadcast = (
         building_block_test_utils.create_whimsy_called_federated_broadcast())
     self.assertFalse(
         tree_analysis.contains_called_intrinsic(
             called_broadcast, intrinsic_defs.FEDERATED_MAP.uri))
Example #6
0
 def test_returns_true_with_matching_uri(self):
     """A called broadcast is found when searching by the broadcast URI."""
     called_broadcast = (
         building_block_test_utils.create_whimsy_called_federated_broadcast())
     self.assertTrue(
         tree_analysis.contains_called_intrinsic(
             called_broadcast, intrinsic_defs.FEDERATED_BROADCAST.uri))
Example #7
0
 def test_returns_true_with_none_uri(self):
     """With no URI argument, a called broadcast counts as a called intrinsic."""
     called_broadcast = (
         building_block_test_utils.create_whimsy_called_federated_broadcast())
     found = tree_analysis.contains_called_intrinsic(called_broadcast)
     self.assertTrue(found)
Example #8
0
    def test_returns_true_for_tuples(self):
        """Two separately built, identically shaped structs compare equal.

        Each struct holds the same `Data('data', tf.int32)` node twice. The
        two structs share no node objects.
        """
        def make_pair():
            # One Data node reused as both elements, matching the original.
            element = building_blocks.Data('data', tf.int32)
            return building_blocks.Struct([element, element])

        left = make_pair()
        right = make_pair()
        self.assertTrue(tree_analysis.trees_equal(left, right))

    def test_returns_true_for_identical_graphs_with_nans(self):
        """Two independently built NaN graphs are considered equal trees.

        Both computations come from `_create_tensorflow_graph_with_nan`.
        Because NaN != NaN, this presumably guards against an equality check
        that compares constant values directly (TODO confirm intent).
        """
        first, second = (
            _create_tensorflow_graph_with_nan() for _ in range(2))
        self.assertTrue(tree_analysis.trees_equal(first, second))


# Struct of two unnamed (None-labelled) called intrinsics that do not
# aggregate: a whimsy called federated_broadcast, and a whimsy called
# federated_value placed at CLIENTS.
non_aggregation_intrinsics = building_blocks.Struct([
    (None,
     building_block_test_utils.create_whimsy_called_federated_broadcast()),
    (None,
     building_block_test_utils.create_whimsy_called_federated_value(
         placements.CLIENTS))
])

# The empty struct type. It is the value type of the "trivial" whimsy
# aggregation calls below.
unit = computation_types.StructType([])
# Whimsy called aggregation intrinsics whose values are of type `unit`.
trivial_aggregate = building_block_test_utils.create_whimsy_called_federated_aggregate(
    value_type=unit)
trivial_mean = building_block_test_utils.create_whimsy_called_federated_mean(
    unit)
trivial_sum = building_block_test_utils.create_whimsy_called_federated_sum(
    unit)
# TODO(b/120439632) Enable once federated_mean accepts structured weights.
# trivial_weighted_mean = ...
trivial_secure_sum = building_block_test_utils.create_whimsy_called_federated_secure_sum_bitwidth(