def test_returns_tree(self):
        """Checks the before/after-aggregate split of a `next` computation.

        Splits the `next` computation of an iterative process (whose name
        suggests it uses no `federated_secure_sum`) with
        `_create_before_and_after_aggregate_for_no_federated_secure_sum`, and
        checks the result against the reference split produced by
        `transformations.force_align_and_split_by_intrinsics` on
        `FEDERATED_AGGREGATE`.

        The second element of `before_aggregate`'s result is expected to be an
        empty `federated_value_at_clients(<>)` placeholder. Presumably this is
        the secure-sum slot; TODO confirm.
        """
        # Fixture: an iterative process whose `next` (per the helper's name)
        # contains no federated_secure_sum.
        ip = get_iterative_process_for_sum_example_with_no_federated_secure_sum(
        )
        # Deserialize `next` into a building-block tree so it can be transformed.
        next_tree = building_blocks.ComputationBuildingBlock.from_proto(
            ip.next._computation_proto)
        # Inline intrinsic bodies before splitting.
        next_tree = canonical_form_utils._replace_intrinsics_with_bodies(
            next_tree)

        # Split under test.
        before_aggregate, after_aggregate = canonical_form_utils._create_before_and_after_aggregate_for_no_federated_secure_sum(
            next_tree)

        # Reference split on FEDERATED_AGGREGATE alone, used for comparison below.
        before_federated_aggregate, after_federated_aggregate = (
            transformations.force_align_and_split_by_intrinsics(
                next_tree, [intrinsic_defs.FEDERATED_AGGREGATE.uri]))
        # `before_aggregate` must be a Lambda whose result is a 2-element Struct.
        self.assertIsInstance(before_aggregate, building_blocks.Lambda)
        self.assertIsInstance(before_aggregate.result, building_blocks.Struct)
        self.assertLen(before_aggregate.result, 2)

        # trees_equal will fail if computations refer to unbound references, so we
        # create a new dummy computation to bind them.
        # The map is keyed by computation; indexing it with the computation
        # itself yields the names left unbound inside it.
        unbound_refs_in_before_agg_result = transformation_utils.get_map_of_unbound_references(
            before_aggregate.result[0])[before_aggregate.result[0]]
        unbound_refs_in_before_fed_agg_result = transformation_utils.get_map_of_unbound_references(
            before_federated_aggregate.result)[
                before_federated_aggregate.result]

        # A single placeholder value (abstract type 'T') bound to every free name.
        dummy_data = building_blocks.Data('data',
                                          computation_types.AbstractType('T'))

        # Wrap each result in a Block that binds its free names to the
        # placeholder, so both trees are closed before comparison.
        blk_binding_refs_in_before_agg = building_blocks.Block(
            [(name, dummy_data) for name in unbound_refs_in_before_agg_result],
            before_aggregate.result[0])
        blk_binding_refs_in_before_fed_agg = building_blocks.Block(
            [(name, dummy_data)
             for name in unbound_refs_in_before_fed_agg_result],
            before_federated_aggregate.result)

        # The first element must match the reference pre-aggregate result.
        self.assertTrue(
            tree_analysis.trees_equal(blk_binding_refs_in_before_agg,
                                      blk_binding_refs_in_before_fed_agg))

        # The second element must be the empty placeholder structure.
        # pyformat: disable
        self.assertEqual(
            before_aggregate.result[1].formatted_representation(), '<\n'
            '  federated_value_at_clients(<>),\n'
            '  <>\n'
            '>')
        # pyformat: enable

        # `after_aggregate` must be a Lambda whose body is a single Call.
        self.assertIsInstance(after_aggregate, building_blocks.Lambda)
        self.assertIsInstance(after_aggregate.result, building_blocks.Call)

        # The function being called must match the reference post-aggregate.
        self.assertTrue(
            tree_analysis.trees_equal(after_aggregate.result.function,
                                      after_federated_aggregate))

        # The call argument must be a 2-tuple of selections from `_var1`.
        # pyformat: disable
        self.assertEqual(
            after_aggregate.result.argument.formatted_representation(), '<\n'
            '  _var1[0],\n'
            '  _var1[1][0]\n'
            '>')
        # pyformat: enable
Example #2
0
 def test_returns_false_for_references_with_different_names(self):
   """References of the same type but with different names are unequal."""
   ref_a = building_blocks.Reference('a', tf.int32)
   ref_b = building_blocks.Reference('b', tf.int32)
   are_equal = tree_analysis.trees_equal(ref_a, ref_b)
   self.assertFalse(are_equal)
Example #3
0
 def test_returns_false_for_selections_with_different_indexes(self):
   """Selections at different indexes of same-typed references are unequal."""
   first = building_blocks.Selection(
       building_blocks.Reference('a', [tf.int32, tf.int32]), index=0)
   second = building_blocks.Selection(
       building_blocks.Reference('a', [tf.int32, tf.int32]), index=1)
   self.assertFalse(tree_analysis.trees_equal(first, second))
Example #4
0
 def test_returns_true_for_lambdas_referring_to_same_unbound_variables(self):
   """Lambdas whose bodies share the same free reference `x` are equal."""
   free_x = building_blocks.Reference('x', tf.int32)
   lambdas = [building_blocks.Lambda('a', tf.int32, free_x) for _ in range(2)]
   self.assertTrue(tree_analysis.trees_equal(*lambdas))
Example #5
0
 def test_returns_false_for_placements_with_literals(self):
   """Placements with different literals (CLIENTS vs SERVER) are unequal."""
   clients = building_blocks.Placement(placements.CLIENTS)
   server = building_blocks.Placement(placements.SERVER)
   self.assertFalse(tree_analysis.trees_equal(clients, server))
Example #6
0
 def test_returns_false_for_intrinsics_with_different_names(self):
   """Intrinsics with equal tensor types but different names are unequal."""
   intrinsic_a = building_blocks.Intrinsic(
       'a', computation_types.TensorType(tf.int32))
   intrinsic_b = building_blocks.Intrinsic(
       'b', computation_types.TensorType(tf.int32))
   self.assertFalse(tree_analysis.trees_equal(intrinsic_a, intrinsic_b))
Example #7
0
 def test_returns_true_for_lambdas_representing_identical_functions(self):
   """Identity lambdas that differ only in parameter name are equal."""

   def _identity(param_name):
     # Builds `(param_name -> param_name)` over int32.
     param = building_blocks.Reference(param_name, tf.int32)
     return building_blocks.Lambda(param_name, param.type_signature, param)

   self.assertTrue(tree_analysis.trees_equal(_identity('a'), _identity('b')))
 def test_returns_true_for_tuples(self):
     """Tuples built from equal elements compare equal."""
     first_data = building_blocks.Data('data', tf.int32)
     first = building_blocks.Tuple([first_data] * 2)
     second_data = building_blocks.Data('data', tf.int32)
     second = building_blocks.Tuple([second_data] * 2)
     self.assertTrue(tree_analysis.trees_equal(first, second))
Example #9
0
 def test_raises_type_error(self):
   """A non-building-block argument in either position raises TypeError."""
   data = building_blocks.Data('data', tf.int32)
   for args in ((data, 0), (0, data)):
     with self.assertRaises(TypeError):
       tree_analysis.trees_equal(*args)
 def test_returns_false_for_tuples_with_different_names(self):
     """Tuples that differ ONLY in their element names are unequal."""
     # Both tuples must use the same element type so that the names are the
     # sole difference. With mismatched dtypes the assertion would pass on the
     # type mismatch alone and never exercise name comparison.
     data_1 = building_blocks.Data('data', tf.int32)
     tuple_1 = building_blocks.Tuple([('a', data_1), ('b', data_1)])
     data_2 = building_blocks.Data('data', tf.int32)
     tuple_2 = building_blocks.Tuple([('c', data_2), ('d', data_2)])
     self.assertFalse(tree_analysis.trees_equal(tuple_1, tuple_2))
 def test_returns_false_for_tuples_with_different_elements(self):
     """Tuples whose elements have different tensor types are unequal."""
     int_data = building_blocks.Data('data', tf.int32)
     int_tuple = building_blocks.Tuple([int_data, int_data])
     float_data = building_blocks.Data('data', tf.float32)
     float_tuple = building_blocks.Tuple([float_data, float_data])
     self.assertFalse(tree_analysis.trees_equal(int_tuple, float_tuple))
 def test_returns_true_for_intrinsics(self):
     """Two intrinsics with the same name and type are equal."""
     lhs, rhs = (building_blocks.Intrinsic('intrinsic', tf.int32)
                 for _ in range(2))
     self.assertTrue(tree_analysis.trees_equal(lhs, rhs))
 def test_returns_false_for_intrinsics_with_different_names(self):
     """Intrinsics with the same type but different names are unequal."""
     intrinsic_a, intrinsic_b = [
         building_blocks.Intrinsic(name, tf.int32) for name in ('a', 'b')
     ]
     self.assertFalse(tree_analysis.trees_equal(intrinsic_a, intrinsic_b))
 def test_returns_true_for_compiled_computations_with_different_names(self):
     """Compiled int32 identities differing only in name are equal."""
     identity_a, identity_b = (
         building_block_factory.create_compiled_identity(tf.int32, name)
         for name in ('a', 'b'))
     self.assertTrue(tree_analysis.trees_equal(identity_a, identity_b))
Example #15
0
 def test_returns_false_for_data_with_different_names(self):
   """Data blocks with the same type but different names are unequal."""
   lhs = building_blocks.Data('a', tf.int32)
   rhs = building_blocks.Data('b', tf.int32)
   self.assertFalse(tree_analysis.trees_equal(lhs, rhs))
Example #16
0
 def test_returns_false_for_block_and_none(self):
   """A building block compared with None, in either order, is unequal."""
   data = building_blocks.Data('data', tf.int32)
   for lhs, rhs in ((data, None), (None, data)):
     self.assertFalse(tree_analysis.trees_equal(lhs, rhs))
Example #17
0
 def test_returns_true_for_data(self):
   """Distinct Data instances with the same name and type are equal."""
   first, second = [building_blocks.Data('data', tf.int32) for _ in range(2)]
   self.assertTrue(tree_analysis.trees_equal(first, second))
Example #18
0
 def test_returns_true_for_none_and_none(self):
   """Two absent (None) computations compare equal."""
   result = tree_analysis.trees_equal(None, None)
   self.assertTrue(result)
Example #19
0
 def test_returns_true_for_intrinsics(self):
   """Intrinsics with the same name and equal tensor types are equal."""

   def _make_intrinsic():
     # A fresh TensorType per call, mirroring two independent constructions.
     return building_blocks.Intrinsic(
         'intrinsic', computation_types.TensorType(tf.int32))

   self.assertTrue(
       tree_analysis.trees_equal(_make_intrinsic(), _make_intrinsic()))
Example #20
0
 def test_returns_true_for_the_same_comp(self):
   """A computation compared with itself is equal."""
   comp = building_blocks.Data('data', tf.int32)
   self.assertTrue(tree_analysis.trees_equal(comp, comp))
Example #21
0
 def test_returns_false_for_lambdas_with_different_parameter_types(self):
   """Identity lambdas over int32 and float32 parameters are unequal."""

   def _identity(dtype):
     # Builds `(a -> a)` where `a` has the given dtype.
     param = building_blocks.Reference('a', dtype)
     return building_blocks.Lambda(param.name, param.type_signature, param)

   self.assertFalse(
       tree_analysis.trees_equal(_identity(tf.int32), _identity(tf.float32)))
Example #22
0
 def test_returns_false_for_comps_with_different_types(self):
   """A Data and a Reference are unequal, regardless of argument order."""
   data = building_blocks.Data('data', tf.int32)
   ref = building_blocks.Reference('a', tf.int32)
   for lhs, rhs in ((data, ref), (ref, data)):
     self.assertFalse(tree_analysis.trees_equal(lhs, rhs))
Example #23
0
 def test_returns_true_for_lambdas(self):
   """Two independently built identical identity lambdas are equal."""

   def _identity():
     # Builds `(a -> a)` over int32.
     param = building_blocks.Reference('a', tf.int32)
     return building_blocks.Lambda(param.name, param.type_signature, param)

   self.assertTrue(tree_analysis.trees_equal(_identity(), _identity()))
Example #24
0
 def test_returns_false_for_blocks_with_different_results(self):
   """Local-free blocks whose results differ in type are unequal."""
   int_block = building_blocks.Block(
       [], building_blocks.Data('data', tf.int32))
   float_block = building_blocks.Block(
       [], building_blocks.Data('data', tf.float32))
   self.assertFalse(tree_analysis.trees_equal(int_block, float_block))
Example #25
0
 def test_returns_true_for_placements(self):
   """Two CLIENTS placements are equal."""
   lhs, rhs = (building_blocks.Placement(placements.CLIENTS) for _ in range(2))
   self.assertTrue(tree_analysis.trees_equal(lhs, rhs))
Example #26
0
 def test_returns_true_for_blocks_with_different_variable_names(self):
   """Blocks differing only in their local variable's name are equal."""
   data = building_blocks.Data('data', tf.int32)
   block_a, block_b = [
       building_blocks.Block([(local_name, data)], data)
       for local_name in ('a', 'b')
   ]
   self.assertTrue(tree_analysis.trees_equal(block_a, block_b))
Example #27
0
 def test_returns_true_for_references(self):
   """References with the same name and type are equal."""
   first, second = [building_blocks.Reference('a', tf.int32) for _ in range(2)]
   self.assertTrue(tree_analysis.trees_equal(first, second))
Example #28
0
 def test_returns_true_for_blocks(self):
   """Independently built blocks with equal locals and results are equal."""

   def _make_block():
     # The same Data instance is used as both the local value and the result.
     value = building_blocks.Data('data', tf.int32)
     return building_blocks.Block([('a', value)], value)

   self.assertTrue(tree_analysis.trees_equal(_make_block(), _make_block()))
Example #29
0
 def test_returns_false_for_selections_with_differet_names(self):
   """Selecting different named fields of same-typed references is unequal."""

   def _select(field):
     # Selects `field` by name from a fresh `<a=int32,b=int32>` reference.
     source = building_blocks.Reference('a', [('a', tf.int32),
                                              ('b', tf.int32)])
     return building_blocks.Selection(source, name=field)

   self.assertFalse(tree_analysis.trees_equal(_select('a'), _select('b')))
Example #30
0
 def test_returns_true_for_identical_graphs_with_nans(self):
     """Two graphs from `_create_tensorflow_graph_with_nan` compare equal."""
     graphs = [_create_tensorflow_graph_with_nan() for _ in range(2)]
     self.assertTrue(tree_analysis.trees_equal(*graphs))