def test_returns_tree(self):
  """Checks the before/after-aggregate split when there is no secure sum.

  The split produced by
  `_create_before_and_after_aggregate_for_no_federated_secure_sum` is
  compared against a direct `force_align_and_split_by_intrinsics` split on
  the FEDERATED_AGGREGATE intrinsic. The remaining pieces are checked against
  fixed formatted representations.
  """
  ip = get_iterative_process_for_sum_example_with_no_federated_secure_sum()
  next_tree = building_blocks.ComputationBuildingBlock.from_proto(
      ip.next._computation_proto)
  next_tree = canonical_form_utils._replace_intrinsics_with_bodies(next_tree)
  before_aggregate, after_aggregate = (
      canonical_form_utils._create_before_and_after_aggregate_for_no_federated_secure_sum(
          next_tree))
  before_federated_aggregate, after_federated_aggregate = (
      transformations.force_align_and_split_by_intrinsics(
          next_tree, [intrinsic_defs.FEDERATED_AGGREGATE.uri]))

  self.assertIsInstance(before_aggregate, building_blocks.Lambda)
  self.assertIsInstance(before_aggregate.result, building_blocks.Struct)
  self.assertLen(before_aggregate.result, 2)

  # trees_equal will fail if computations refer to unbound references, so we
  # create a new dummy computation to bind them.
  before_agg_result = before_aggregate.result[0]
  before_fed_agg_result = before_federated_aggregate.result
  unbound_refs_in_before_agg_result = (
      transformation_utils.get_map_of_unbound_references(before_agg_result)[
          before_agg_result])
  unbound_refs_in_before_fed_agg_result = (
      transformation_utils.get_map_of_unbound_references(
          before_fed_agg_result)[before_fed_agg_result])
  dummy_data = building_blocks.Data('data',
                                    computation_types.AbstractType('T'))
  # Bind the names in sorted order. The iteration order of the returned name
  # collection is not something to rely on, and the two Blocks' locals are
  # compared position by position. Sorting makes equal name collections
  # produce identically ordered locals.
  blk_binding_refs_in_before_agg = building_blocks.Block(
      [(name, dummy_data)
       for name in sorted(unbound_refs_in_before_agg_result)],
      before_agg_result)
  blk_binding_refs_in_before_fed_agg = building_blocks.Block(
      [(name, dummy_data)
       for name in sorted(unbound_refs_in_before_fed_agg_result)],
      before_fed_agg_result)
  self.assertTrue(
      tree_analysis.trees_equal(blk_binding_refs_in_before_agg,
                                blk_binding_refs_in_before_fed_agg))

  # NOTE(review): the indentation inside the expected strings below may have
  # been collapsed in this copy of the file; confirm against the indent that
  # formatted_representation() actually emits.
  # pyformat: disable
  self.assertEqual(
      before_aggregate.result[1].formatted_representation(),
      '<\n'
      ' federated_value_at_clients(<>),\n'
      ' <>\n'
      '>'
  )
  # pyformat: enable

  self.assertIsInstance(after_aggregate, building_blocks.Lambda)
  self.assertIsInstance(after_aggregate.result, building_blocks.Call)
  self.assertTrue(
      tree_analysis.trees_equal(after_aggregate.result.function,
                                after_federated_aggregate))
  # pyformat: disable
  self.assertEqual(
      after_aggregate.result.argument.formatted_representation(),
      '<\n'
      ' _var1[0],\n'
      ' _var1[1][0]\n'
      '>'
  )
  # pyformat: enable
def test_returns_false_for_references_with_different_names(self):
  """References of the same type but different names are not equal."""
  ref_a = building_blocks.Reference('a', tf.int32)
  ref_b = building_blocks.Reference('b', tf.int32)
  are_equal = tree_analysis.trees_equal(ref_a, ref_b)
  self.assertFalse(are_equal)
def test_returns_false_for_selections_with_different_indexes(self):
  """Selecting index 0 vs. index 1 from equivalent references is unequal."""
  first = building_blocks.Selection(
      building_blocks.Reference('a', [tf.int32, tf.int32]), index=0)
  second = building_blocks.Selection(
      building_blocks.Reference('a', [tf.int32, tf.int32]), index=1)
  self.assertFalse(tree_analysis.trees_equal(first, second))
def test_returns_true_for_lambdas_referring_to_same_unbound_variables(self):
  """Lambdas whose body is the same unbound reference `x` are equal."""
  free_x = building_blocks.Reference('x', tf.int32)
  lambdas = [building_blocks.Lambda('a', tf.int32, free_x) for _ in range(2)]
  self.assertTrue(tree_analysis.trees_equal(lambdas[0], lambdas[1]))
def test_returns_false_for_placements_with_literals(self):
  """Placements built from different literals (CLIENTS vs. SERVER) differ."""
  clients = building_blocks.Placement(placements.CLIENTS)
  server = building_blocks.Placement(placements.SERVER)
  self.assertFalse(tree_analysis.trees_equal(clients, server))
def test_returns_false_for_intrinsics_with_different_names(self):
  """Intrinsics with equal TensorTypes but different names are unequal."""
  # NOTE(review): another test method in this file has this exact name; if
  # both live in the same class, the later definition shadows this one.
  intrinsic_a = building_blocks.Intrinsic(
      'a', computation_types.TensorType(tf.int32))
  intrinsic_b = building_blocks.Intrinsic(
      'b', computation_types.TensorType(tf.int32))
  self.assertFalse(tree_analysis.trees_equal(intrinsic_a, intrinsic_b))
def test_returns_true_for_lambdas_representing_identical_functions(self):
  """Identity lambdas are equal even when their parameter names differ."""
  identities = []
  for param_name in ('a', 'b'):
    param_ref = building_blocks.Reference(param_name, tf.int32)
    identities.append(
        building_blocks.Lambda(param_name, param_ref.type_signature,
                               param_ref))
  fn_a, fn_b = identities
  self.assertTrue(tree_analysis.trees_equal(fn_a, fn_b))
def test_returns_true_for_tuples(self):
  """Tuples holding matching Data elements are equal."""
  left_data = building_blocks.Data('data', tf.int32)
  left = building_blocks.Tuple([left_data] * 2)
  right_data = building_blocks.Data('data', tf.int32)
  right = building_blocks.Tuple([right_data] * 2)
  self.assertTrue(tree_analysis.trees_equal(left, right))
def test_raises_type_error(self):
  """A non-building-block argument on either side raises TypeError."""
  data = building_blocks.Data('data', tf.int32)
  for lhs, rhs in ((data, 0), (0, data)):
    with self.assertRaises(TypeError):
      tree_analysis.trees_equal(lhs, rhs)
def test_returns_false_for_tuples_with_different_names(self):
  """Tuples whose elements differ only in their names are not equal."""
  # Both tuples hold identically typed Data so that the element names are the
  # only difference. With differing dtypes, the test would pass even if
  # trees_equal ignored the names entirely.
  data_1 = building_blocks.Data('data', tf.int32)
  tuple_1 = building_blocks.Tuple([('a', data_1), ('b', data_1)])
  data_2 = building_blocks.Data('data', tf.int32)
  tuple_2 = building_blocks.Tuple([('c', data_2), ('d', data_2)])
  self.assertFalse(tree_analysis.trees_equal(tuple_1, tuple_2))
def test_returns_false_for_tuples_with_different_elements(self):
  """Tuples whose Data elements differ in dtype are unequal."""
  int_data = building_blocks.Data('data', tf.int32)
  int_tuple = building_blocks.Tuple([int_data, int_data])
  float_data = building_blocks.Data('data', tf.float32)
  float_tuple = building_blocks.Tuple([float_data, float_data])
  self.assertFalse(tree_analysis.trees_equal(int_tuple, float_tuple))
def test_returns_true_for_intrinsics(self):
  """Intrinsics with the same name and type are equal."""
  # NOTE(review): another test method in this file has this exact name; if
  # both live in the same class, the later definition shadows this one.
  lhs, rhs = (
      building_blocks.Intrinsic('intrinsic', tf.int32) for _ in range(2))
  self.assertTrue(tree_analysis.trees_equal(lhs, rhs))
def test_returns_false_for_intrinsics_with_different_names(self):
  """Intrinsics of the same type but different names are unequal."""
  # NOTE(review): another test method in this file has this exact name; if
  # both live in the same class, only the later definition is collected.
  by_name = {
      name: building_blocks.Intrinsic(name, tf.int32) for name in ('a', 'b')
  }
  self.assertFalse(tree_analysis.trees_equal(by_name['a'], by_name['b']))
def test_returns_true_for_compiled_computations_with_different_names(self):
  """Compiled identities that differ only in name compare equal."""
  compiled_a, compiled_b = [
      building_block_factory.create_compiled_identity(tf.int32, name)
      for name in ('a', 'b')
  ]
  self.assertTrue(tree_analysis.trees_equal(compiled_a, compiled_b))
def test_returns_false_for_data_with_different_names(self):
  """Data blocks with the same type but different names are unequal."""
  data_a = building_blocks.Data('a', tf.int32)
  data_b = building_blocks.Data('b', tf.int32)
  result = tree_analysis.trees_equal(data_a, data_b)
  self.assertFalse(result)
def test_returns_false_for_block_and_none(self):
  """A computation compared with None, in either order, is unequal."""
  # The "block" in the test name is a generic building block; a Data leaf is
  # used here.
  comp = building_blocks.Data('data', tf.int32)
  for lhs, rhs in ((comp, None), (None, comp)):
    self.assertFalse(tree_analysis.trees_equal(lhs, rhs))
def test_returns_true_for_data(self):
  """Two Data blocks with the same name and type are equal."""
  first, second = (building_blocks.Data('data', tf.int32),
                   building_blocks.Data('data', tf.int32))
  self.assertTrue(tree_analysis.trees_equal(first, second))
def test_returns_true_for_none_and_none(self):
  """Comparing None with None reports equality."""
  both_none_equal = tree_analysis.trees_equal(None, None)
  self.assertTrue(both_none_equal)
def test_returns_true_for_intrinsics(self):
  """Intrinsics with the same name and equal TensorTypes are equal."""
  # NOTE(review): another test method in this file has this exact name; if
  # both live in the same class, only the later definition is collected.
  first = building_blocks.Intrinsic('intrinsic',
                                    computation_types.TensorType(tf.int32))
  second = building_blocks.Intrinsic('intrinsic',
                                     computation_types.TensorType(tf.int32))
  self.assertTrue(tree_analysis.trees_equal(first, second))
def test_returns_true_for_the_same_comp(self):
  """A computation compared with itself is equal."""
  comp = building_blocks.Data('data', tf.int32)
  is_equal = tree_analysis.trees_equal(comp, comp)
  self.assertTrue(is_equal)
def test_returns_false_for_lambdas_with_different_parameter_types(self):
  """Identity lambdas over int32 vs. float32 parameters are unequal."""

  def make_identity(dtype):
    param = building_blocks.Reference('a', dtype)
    return building_blocks.Lambda(param.name, param.type_signature, param)

  self.assertFalse(
      tree_analysis.trees_equal(
          make_identity(tf.int32), make_identity(tf.float32)))
def test_returns_false_for_comps_with_different_types(self):
  """Different building-block kinds (Data vs. Reference) are unequal."""
  data = building_blocks.Data('data', tf.int32)
  ref = building_blocks.Reference('a', tf.int32)
  for lhs, rhs in ((data, ref), (ref, data)):
    self.assertFalse(tree_analysis.trees_equal(lhs, rhs))
def test_returns_true_for_lambdas(self):
  """Two independently built, identical identity lambdas are equal."""

  def make_identity():
    param = building_blocks.Reference('a', tf.int32)
    return building_blocks.Lambda(param.name, param.type_signature, param)

  self.assertTrue(tree_analysis.trees_equal(make_identity(), make_identity()))
def test_returns_false_for_blocks_with_different_results(self):
  """Blocks without locals whose results differ in dtype are unequal."""
  int_block = building_blocks.Block([],
                                    building_blocks.Data('data', tf.int32))
  float_block = building_blocks.Block(
      [], building_blocks.Data('data', tf.float32))
  self.assertFalse(tree_analysis.trees_equal(int_block, float_block))
def test_returns_true_for_placements(self):
  """Two CLIENTS placements are equal."""
  lhs, rhs = [building_blocks.Placement(placements.CLIENTS) for _ in range(2)]
  self.assertTrue(tree_analysis.trees_equal(lhs, rhs))
def test_returns_true_for_blocks_with_different_variable_names(self):
  """Blocks differing only in the name of their bound local are equal."""
  data = building_blocks.Data('data', tf.int32)
  block_a, block_b = (
      building_blocks.Block([(local_name, data)], data)
      for local_name in ('a', 'b'))
  self.assertTrue(tree_analysis.trees_equal(block_a, block_b))
def test_returns_true_for_references(self):
  """References with the same name and type are equal."""
  refs = [building_blocks.Reference('a', tf.int32) for _ in range(2)]
  self.assertTrue(tree_analysis.trees_equal(*refs))
def test_returns_true_for_blocks(self):
  """Blocks with matching locals and results are equal."""

  def make_block():
    value = building_blocks.Data('data', tf.int32)
    return building_blocks.Block([('a', value)], value)

  self.assertTrue(tree_analysis.trees_equal(make_block(), make_block()))
def test_returns_false_for_selections_with_differet_names(self):
  """Selecting field 'a' vs. field 'b' from equivalent references differs."""
  # NOTE(review): "differet" in the method name is a typo for "different";
  # it is left unchanged to keep the test identifier stable.
  first = building_blocks.Selection(
      building_blocks.Reference('a', [('a', tf.int32), ('b', tf.int32)]),
      name='a')
  second = building_blocks.Selection(
      building_blocks.Reference('a', [('a', tf.int32), ('b', tf.int32)]),
      name='b')
  self.assertFalse(tree_analysis.trees_equal(first, second))
def test_returns_true_for_identical_graphs_with_nans(self):
  """Two graphs built by `_create_tensorflow_graph_with_nan` compare equal."""
  graphs = [_create_tensorflow_graph_with_nan() for _ in range(2)]
  self.assertTrue(tree_analysis.trees_equal(graphs[0], graphs[1]))