def test_returns_single_called_graph_after_resolving_multiple_variables(
    self):
  """Two chained identity calls on unbound `var` collapse to one TF call.

  The block binds `call` to `identity(var)` and `second_call` to
  `identity(call)` and returns `second_call`. The resulting computation must be
  a single call of a compiled graph whose argument is still the reference
  `var`.
  """
  int_type = computation_types.TensorType(tf.int32)
  ref_to_int = building_blocks.Reference('var', tf.int32)
  first_tf_id = building_block_factory.create_compiled_identity(int_type)
  called_tf_id = building_blocks.Call(first_tf_id, ref_to_int)
  ref_to_call = building_blocks.Reference('call', called_tf_id.type_signature)
  second_tf_id = building_block_factory.create_compiled_identity(int_type)
  second_called = building_blocks.Call(second_tf_id, ref_to_call)
  # Type the reference from the binding it actually names (`second_call`).
  ref_to_second_call = building_blocks.Reference('second_call',
                                                 second_called.type_signature)
  block_locals = [('call', called_tf_id), ('second_call', second_called)]
  block = building_blocks.Block(block_locals, ref_to_second_call)

  tf_representing_block, _ = (
      transformations.create_tensorflow_representing_block(block))

  self.assertEqual(tf_representing_block.type_signature, block.type_signature)
  self.assertIsInstance(tf_representing_block, building_blocks.Call)
  self.assertIsInstance(tf_representing_block.function,
                        building_blocks.CompiledComputation)
  self.assertIsInstance(tf_representing_block.argument,
                        building_blocks.Reference)
  self.assertEqual(tf_representing_block.argument.name, 'var')
def test_replaces_lambda_to_named_tuple_of_called_graphs_with_tf_of_same_type(
    self):
  """A lambda returning `<a=id(x[0]), b=id(x[1])>` parses to one graph.

  The parsed result must be a `CompiledComputation`, be reported as modified,
  keep the lambda's type signature, and agree with the lambda on a sample.
  """
  int_id = building_block_factory.create_compiled_identity(tf.int32)
  float_id = building_block_factory.create_compiled_identity(tf.float32)
  arg = building_blocks.Reference('x', [tf.int32, tf.float32])
  first = building_blocks.Selection(arg, index=0)
  second = building_blocks.Selection(arg, index=1)
  named_calls = building_blocks.Tuple([
      ('a', building_blocks.Call(int_id, first)),
      ('b', building_blocks.Call(float_id, second)),
  ])
  fn = building_blocks.Lambda('x', [tf.int32, tf.float32], named_calls)

  parsed, modified = parse_tff_to_tf(fn)
  expected_fn = computation_wrapper_instances.building_block_to_computation(fn)
  actual_fn = computation_wrapper_instances.building_block_to_computation(
      parsed)

  self.assertIsInstance(parsed, building_blocks.CompiledComputation)
  self.assertTrue(modified)
  self.assertEqual(parsed.type_signature, fn.type_signature)
  self.assertEqual(expected_fn([13, 14.]), actual_fn([13, 14.]))
def test_returned_tensorflow_executes_correctly_with_no_unbound_refs(self):
  """A fully-bound block compiles to TensorFlow that evaluates to `[1, 1]`.

  `call` binds `identity(1)`, `second_call` binds `identity(call)`, and the
  block returns the pair `<second_call, second_call>`.
  """
  int_type = computation_types.TensorType(tf.int32)
  concrete_int = building_block_factory.create_tensorflow_constant(int_type, 1)
  first_tf_id = building_block_factory.create_compiled_identity(int_type)
  called_tf_id = building_blocks.Call(first_tf_id, concrete_int)
  ref_to_call = building_blocks.Reference('call', called_tf_id.type_signature)
  second_tf_id = building_block_factory.create_compiled_identity(int_type)
  second_called = building_blocks.Call(second_tf_id, ref_to_call)
  # Type the reference from the binding it actually names (`second_call`).
  ref_to_second_call = building_blocks.Reference('second_call',
                                                 second_called.type_signature)
  block_locals = [('call', called_tf_id), ('second_call', second_called)]
  block = building_blocks.Block(
      block_locals,
      building_blocks.Tuple([ref_to_second_call, ref_to_second_call]))

  tf_representing_block, _ = (
      transformations.create_tensorflow_representing_block(block))

  result = test_utils.run_tensorflow(tf_representing_block.function.proto)
  self.assertAllEqual(result, [1, 1])
def test_replaces_lambda_to_named_tuple_of_called_graphs_with_tf_of_same_type(
    self):
  """A named struct of called identity graphs under a lambda becomes one graph.

  Checks that the parsed result is a `CompiledComputation`, is flagged as
  modified, has a type equivalent to the lambda's, and matches it on a sample.
  """
  int_id = building_block_factory.create_compiled_identity(
      computation_types.TensorType(tf.int32))
  float_id = building_block_factory.create_compiled_identity(
      computation_types.TensorType(tf.float32))
  arg = building_blocks.Reference('x', [tf.int32, tf.float32])
  int_part = building_blocks.Selection(arg, index=0)
  float_part = building_blocks.Selection(arg, index=1)
  body = building_blocks.Struct([
      ('a', building_blocks.Call(int_id, int_part)),
      ('b', building_blocks.Call(float_id, float_part)),
  ])
  fn = building_blocks.Lambda('x', [tf.int32, tf.float32], body)

  parsed, modified = parse_tff_to_tf(fn)
  expected_fn = computation_wrapper_instances.building_block_to_computation(fn)
  actual_fn = computation_wrapper_instances.building_block_to_computation(
      parsed)

  self.assertIsInstance(parsed, building_blocks.CompiledComputation)
  self.assertTrue(modified)
  # TODO(b/157172423): change to assertEqual when Py container is preserved.
  parsed.type_signature.check_equivalent_to(fn.type_signature)
  self.assertEqual(expected_fn([13, 14.]), actual_fn([13, 14.]))
def test_executes_correctly_with_tuple_in_result(self):
  """A block over unbound `var` compiles to TF that returns `[var, var]`.

  `call` binds `identity(var)`, `second_call` binds `identity(call)`, and the
  block returns `<second_call, second_call>`. The compiled graph is run with
  inputs 1 and 0.
  """
  int_type = computation_types.TensorType(tf.int32)
  ref_to_int = building_blocks.Reference('var', tf.int32)
  first_tf_id = building_block_factory.create_compiled_identity(int_type)
  called_tf_id = building_blocks.Call(first_tf_id, ref_to_int)
  ref_to_call = building_blocks.Reference('call', called_tf_id.type_signature)
  second_tf_id = building_block_factory.create_compiled_identity(int_type)
  second_called = building_blocks.Call(second_tf_id, ref_to_call)
  # Type the reference from the binding it actually names (`second_call`).
  ref_to_second_call = building_blocks.Reference('second_call',
                                                 second_called.type_signature)
  block_locals = [('call', called_tf_id), ('second_call', second_called)]
  block = building_blocks.Block(
      block_locals,
      building_blocks.Tuple([ref_to_second_call, ref_to_second_call]))

  tf_representing_block, _ = (
      transformations.create_tensorflow_representing_block(block))

  result_ones = test_utils.run_tensorflow(
      tf_representing_block.function.proto, 1)
  self.assertAllEqual(result_ones, [1, 1])
  result_zeros = test_utils.run_tensorflow(
      tf_representing_block.function.proto, 0)
  self.assertAllEqual(result_zeros, [0, 0])
def test_returns_correct_structure_with_no_unbound_references(self):
  """A fully-bound block becomes a no-argument call to one compiled graph.

  `call` binds `identity(1)`, `second_call` binds `identity(call)`, and the
  block returns `<second_call, second_call>`.
  """
  int_type = computation_types.TensorType(tf.int32)
  concrete_int = building_block_factory.create_tensorflow_constant(int_type, 1)
  first_tf_id = building_block_factory.create_compiled_identity(int_type)
  called_tf_id = building_blocks.Call(first_tf_id, concrete_int)
  ref_to_call = building_blocks.Reference('call', called_tf_id.type_signature)
  second_tf_id = building_block_factory.create_compiled_identity(int_type)
  second_called = building_blocks.Call(second_tf_id, ref_to_call)
  # Type the reference from the binding it actually names (`second_call`).
  ref_to_second_call = building_blocks.Reference('second_call',
                                                 second_called.type_signature)
  block_locals = [('call', called_tf_id), ('second_call', second_called)]
  block = building_blocks.Block(
      block_locals,
      building_blocks.Tuple([ref_to_second_call, ref_to_second_call]))

  tf_representing_block, _ = (
      transformations.create_tensorflow_representing_block(block))

  self.assertEqual(tf_representing_block.type_signature, block.type_signature)
  self.assertIsInstance(tf_representing_block, building_blocks.Call)
  self.assertIsInstance(tf_representing_block.function,
                        building_blocks.CompiledComputation)
  self.assertIsNone(tf_representing_block.argument)
def test_returns_true_for_compiled_computations_with_different_names(self):
  """Compiled identities that differ only in name compare as equal trees."""
  int_type = computation_types.TensorType(tf.int32)
  named_a, named_b = (
      building_block_factory.create_compiled_identity(int_type, name)
      for name in ('a', 'b'))
  self.assertTrue(tree_analysis.trees_equal(named_a, named_b))
def test_returns_false_for_compiled_computations_with_different_types(
    self):
  """Same-named compiled identities over int32 and float32 are not equal."""
  int_comp, float_comp = [
      building_block_factory.create_compiled_identity(dtype, 'a')
      for dtype in (tf.int32, tf.float32)
  ]
  self.assertFalse(tree_analysis.trees_equal(int_comp, float_comp))
def test_returns_correct_structure_with_tuple_in_result(self):
  """A block returning a pair becomes a single TF call on unbound `var`.

  `call` binds `identity(var)`, `second_call` binds `identity(call)`, and the
  block returns `<second_call, second_call>`.
  """
  ref_to_int = building_blocks.Reference('var', tf.int32)
  first_tf_id = building_block_factory.create_compiled_identity(tf.int32)
  called_tf_id = building_blocks.Call(first_tf_id, ref_to_int)
  ref_to_call = building_blocks.Reference('call', called_tf_id.type_signature)
  second_tf_id = building_block_factory.create_compiled_identity(tf.int32)
  second_called = building_blocks.Call(second_tf_id, ref_to_call)
  # Type the reference from the binding it actually names (`second_call`).
  ref_to_second_call = building_blocks.Reference('second_call',
                                                 second_called.type_signature)
  block_locals = [('call', called_tf_id), ('second_call', second_called)]
  block = building_blocks.Block(
      block_locals,
      building_blocks.Tuple([ref_to_second_call, ref_to_second_call]))

  tf_representing_block, _ = (
      transformations.create_tensorflow_representing_block(block))

  self.assertEqual(tf_representing_block.type_signature, block.type_signature)
  self.assertIsInstance(tf_representing_block, building_blocks.Call)
  self.assertIsInstance(tf_representing_block.function,
                        building_blocks.CompiledComputation)
  self.assertIsInstance(tf_representing_block.argument,
                        building_blocks.Reference)
  self.assertEqual(tf_representing_block.argument.name, 'var')
def test_basic_functionality_of_compiled_computation_class(self):
  """Checks the type, proto, name, and string forms of a compiled identity.

  Also round-trips an unnamed compiled identity through serialization.
  """
  comp = building_block_factory.create_compiled_identity(tf.int32, 'a')
  self.assertEqual(comp.type_signature.compact_representation(),
                   '(int32 -> int32)')
  self.assertIsInstance(comp.proto, pb.Computation)
  self.assertEqual(comp.name, 'a')
  # `assertTrue(actual, expected)` treats `expected` as the failure message
  # and passes for any non-empty string, so compare the values explicitly.
  self.assertEqual(
      repr(comp),
      'CompiledComputation(\'a\', FunctionType(TensorType(tf.int32), '
      'TensorType(tf.int32)))')
  self.assertEqual(comp.compact_representation(), 'comp#a')
  unnamed = building_block_factory.create_compiled_identity(tf.int32)
  self._serialize_deserialize_roundtrip_test(unnamed)
def test_should_transform_tf_computation(self):
  """`RaiseOnDisallowedOp` reports that it applies to a compiled identity."""
  tensor_type = computation_types.TensorType(tf.int32)
  identity = building_block_factory.create_compiled_identity(tensor_type)
  transform = compiled_computation_transformations.RaiseOnDisallowedOp(
      frozenset())
  self.assertTrue(transform.should_transform(identity))
def test_single_tensorflow_node_count_agrees_with_node_count(self):
  """For a lone compiled graph, the per-node and per-tree op counts agree."""
  identity = building_block_factory.create_compiled_identity(tf.int32)
  in_node = building_block_analysis.count_tensorflow_ops_in(identity)
  under_tree = tree_analysis.count_tensorflow_ops_under(identity)
  self.assertEqual(in_node, under_tree)
def test_replaces_lambda_to_called_graph_on_tuple_of_selections_from_arg_with_tf_of_same_type(
    self):
  """A lambda calling an identity on `<x[0], x[2]>` parses to one graph.

  The parsed result must be a `CompiledComputation`, be reported as modified,
  keep the lambda's type signature, and agree with the lambda on a sample.
  """
  identity_tf_block = building_block_factory.create_compiled_identity(
      [tf.int32, tf.bool])
  tuple_ref = building_blocks.Reference('x', [tf.int32, tf.float32, tf.bool])
  selected_int = building_blocks.Selection(tuple_ref, index=0)
  selected_bool = building_blocks.Selection(tuple_ref, index=2)
  created_tuple = building_blocks.Tuple([selected_int, selected_bool])
  called_tf_block = building_blocks.Call(identity_tf_block, created_tuple)
  lambda_wrapper = building_blocks.Lambda(
      'x', [tf.int32, tf.float32, tf.bool], called_tf_block)

  parsed, modified = parse_tff_to_tf(lambda_wrapper)

  self.assertIsInstance(parsed, building_blocks.CompiledComputation)
  self.assertTrue(modified)
  self.assertEqual(parsed.type_signature, lambda_wrapper.type_signature)
  # Build each executable once (the original built both pairs twice).
  exec_lambda = computation_wrapper_instances.building_block_to_computation(
      lambda_wrapper)
  exec_tf = computation_wrapper_instances.building_block_to_computation(
      parsed)
  self.assertEqual(exec_lambda([7, 8., True]), exec_tf([7, 8., True]))
def test_raises_with_naked_graph_as_block_local(self):
  """An uncalled compiled graph bound as a block local raises `ValueError`."""
  graph = building_block_factory.create_compiled_identity(tf.int32)
  block = building_blocks.Block(
      [('graph', graph)],
      building_blocks.Reference('graph', graph.type_signature))
  with self.assertRaises(ValueError):
    compiler_transformations.create_tensorflow_representing_block(block)
def test_returns_string_for_compiled_computation(self):
  """Compact, formatted, and structural strings of a compiled comp named `a`."""
  comp = building_block_factory.create_compiled_identity(
      computation_types.TensorType(tf.int32), 'a')
  renderings = (
      (comp.compact_representation, 'comp#a'),
      (comp.formatted_representation, 'comp#a'),
      (comp.structural_representation, 'Compiled(a)'),
  )
  for render, expected in renderings:
    self.assertEqual(render(), expected)
def test_replaces_lambda_to_called_graph_on_tuple_of_selections_from_arg_with_tf_of_same_type_with_names(
    self):
  """Lambda over `<a, b, c>` calling an identity on `<x[0], x[2]>` parses to TF.

  The parsed graph is run on a dict argument and must return `<9, False>`.
  """
  identity = building_block_factory.create_compiled_identity(
      computation_types.StructType([tf.int32, tf.bool]))
  arg = building_blocks.Reference('x', [('a', tf.int32), ('b', tf.float32),
                                        ('c', tf.bool)])
  picked = building_blocks.Struct([
      building_blocks.Selection(arg, index=0),
      building_blocks.Selection(arg, index=2),
  ])
  fn = building_blocks.Lambda('x', [('a', tf.int32), ('b', tf.float32),
                                    ('c', tf.bool)],
                              building_blocks.Call(identity, picked))

  parsed, modified = parse_tff_to_tf(fn)

  self.assertIsInstance(parsed, building_blocks.CompiledComputation)
  self.assertTrue(modified)
  self.assertEqual(parsed.type_signature, fn.type_signature)
  named_arg = {'a': 9, 'b': 10.0, 'c': False}
  result = test_utils.run_tensorflow(parsed.proto, named_arg)
  self.assertEqual(structure.Struct([(None, 9), (None, False)]), result)
def federated_reduce(arg):
  """Reduces `arg[0]` with `intrinsics.federated_aggregate`.

  `arg` is indexed as `(value, zero, op)`. `op` is passed as both the
  accumulate and the merge argument, and a compiled identity over `op`'s
  result type is passed as the report argument.
  """
  value, zero, op = arg[0], arg[1], arg[2]
  report = building_block_factory.create_compiled_identity(
      op.type_signature.result)
  return intrinsics.federated_aggregate(value, zero, op, op, report)
def test_tensorflow_op_count_doubles_number_of_ops_in_two_tuple(self):
  """A 2-tuple of the same compiled graph holds twice that graph's op count."""
  identity = building_block_factory.create_compiled_identity(tf.int32)
  ops_per_graph = building_block_analysis.count_tensorflow_ops_in(identity)
  pair = building_blocks.Tuple([identity, identity])
  self.assertEqual(
      tree_analysis.count_tensorflow_ops_under(pair), 2 * ops_per_graph)
def test_replaces_lambda_to_called_graph_on_tuple_of_selections_from_arg_with_tf_of_same_type(
    self):
  """A lambda calling an identity on `<x[0], x[2]>` parses to one TF graph.

  The parsed graph is run on `[7, 8.0, True]` and must return `<7, True>`.
  """
  identity = building_block_factory.create_compiled_identity(
      computation_types.StructType([tf.int32, tf.bool]))
  arg = building_blocks.Reference('x', [tf.int32, tf.float32, tf.bool])
  picked = building_blocks.Struct([
      building_blocks.Selection(arg, index=0),
      building_blocks.Selection(arg, index=2),
  ])
  fn = building_blocks.Lambda('x', [tf.int32, tf.float32, tf.bool],
                              building_blocks.Call(identity, picked))

  parsed, modified = parse_tff_to_tf(fn)

  self.assertIsInstance(parsed, building_blocks.CompiledComputation)
  self.assertTrue(modified)
  # TODO(b/157172423): change to assertEqual when Py container is preserved.
  parsed.type_signature.check_equivalent_to(fn.type_signature)
  result = test_utils.run_tensorflow(parsed.proto, [7, 8.0, True])
  self.assertEqual(structure.Struct([(None, 7), (None, True)]), result)
def test_replaces_lambda_to_called_graph_on_selection_from_arg_with_tf_of_same_type_with_names(
    self):
  """A lambda over `<a, b>` calling an identity on `x[0]` parses to one graph.

  The parsed result must be a `CompiledComputation`, be reported as modified,
  keep the lambda's type signature, and agree with the lambda on a dict input.
  """
  identity = building_block_factory.create_compiled_identity(tf.int32)
  arg = building_blocks.Reference('x', [('a', tf.int32), ('b', tf.float32)])
  first = building_blocks.Selection(arg, index=0)
  fn = building_blocks.Lambda('x', [('a', tf.int32), ('b', tf.float32)],
                              building_blocks.Call(identity, first))

  parsed, modified = parse_tff_to_tf(fn)
  expected_fn = computation_wrapper_instances.building_block_to_computation(fn)
  actual_fn = computation_wrapper_instances.building_block_to_computation(
      parsed)

  self.assertIsInstance(parsed, building_blocks.CompiledComputation)
  self.assertTrue(modified)
  self.assertEqual(parsed.type_signature, fn.type_signature)
  self.assertEqual(expected_fn({'a': 5, 'b': 6.}), actual_fn({'a': 5, 'b': 6.}))
def test_does_not_reduce_no_unnecessary_ops(self):
  """Pruning a compiled identity graph leaves its TF op count unchanged."""
  original = building_block_factory.create_compiled_identity(tf.int32)
  pruned_proto = proto_transformations.prune_tensorflow_proto(original.proto)
  pruned = building_blocks.CompiledComputation(pruned_proto)
  self.assertEqual(
      building_block_analysis.count_tensorflow_ops_in(original),
      building_block_analysis.count_tensorflow_ops_in(pruned))
def test_replaces_lambda_to_selection_from_called_graph_with_tf_of_same_type(
    self):
  """A lambda selecting element 1 of an identity call parses to one graph.

  The parsed result must be a `CompiledComputation`, be reported as modified,
  have a type equivalent to the lambda's, and agree with it on a sample.
  """
  pair_type = computation_types.StructType([tf.int32, tf.float32])
  identity = building_block_factory.create_compiled_identity(pair_type)
  arg = building_blocks.Reference('x', [tf.int32, tf.float32])
  second_of_call = building_blocks.Selection(
      building_blocks.Call(identity, arg), index=1)
  fn = building_blocks.Lambda('x', [tf.int32, tf.float32], second_of_call)

  parsed, modified = parse_tff_to_tf(fn)
  expected_fn = computation_wrapper_instances.building_block_to_computation(fn)
  actual_fn = computation_wrapper_instances.building_block_to_computation(
      parsed)

  self.assertIsInstance(parsed, building_blocks.CompiledComputation)
  self.assertTrue(modified)
  # TODO(b/157172423): change to assertEqual when Py container is preserved.
  parsed.type_signature.check_equivalent_to(fn.type_signature)
  self.assertEqual(expected_fn([0, 1.]), actual_fn([0, 1.]))
def _create_complex_computation():
  """Builds `(arg -> <mean_0, mean_1>)` from two broadcast/map/mean chains.

  Each of two iterations broadcasts `arg`, maps a compiled identity named `a`
  over the broadcast value, and takes the federated mean. Every intermediate
  is bound as a block local (`broadcast_i`, `map_i`, `mean_i`, in that order),
  and the block returns a struct of references to the `mean_i` locals.
  """
  compiled = building_block_factory.create_compiled_identity(
      computation_types.TensorType(tf.int32), 'a')
  arg_ref = building_blocks.Reference(
      'arg', computation_types.FederatedType(tf.int32, placements.SERVER))
  bindings = []

  def _bind(name, value):
    # Records `name = value` as a block local; returns a reference to it.
    bindings.append((name, value))
    return building_blocks.Reference(name, value.type_signature)

  means = []
  for i in range(2):
    broadcast_ref = _bind(
        f'broadcast_{i}',
        building_block_factory.create_federated_broadcast(arg_ref))
    map_ref = _bind(
        f'map_{i}',
        building_block_factory.create_federated_map(compiled, broadcast_ref))
    means.append(
        _bind(f'mean_{i}',
              building_block_factory.create_federated_mean(map_ref, None)))
  block = building_blocks.Block(bindings, building_blocks.Struct(means))
  # NOTE(review): the lambda parameter is declared as `tf.int32`, while
  # `arg_ref` is typed as a SERVER-placed int32; confirm this is intended.
  return building_blocks.Lambda('arg', tf.int32, block)
def test_transform_compiled_computation_returns_compiled_computation_with_id(
    self):
  """`AddUniqueIDs` stamps compiled computations with non-zero cache-key ids.

  The same computation gets the same id across repeated transforms and across
  a fresh `AddUniqueIDs` instance; a different computation gets a different id.
  """
  tensor_type = computation_types.TensorType(tf.int32)
  compiled_computation = building_block_factory.create_compiled_identity(
      tensor_type)
  add_ids = compiled_computation_transformations.AddUniqueIDs()

  def _key_id(comp):
    # Reads the cache-key id from the computation's TensorFlow proto.
    return comp.proto.tensorflow.cache_key.id

  def _assert_has_non_zero_key(comp):
    # Checks that a cache key is present and that its id is non-zero.
    self.assertTrue(comp.proto.tensorflow.HasField('cache_key'))
    self.assertNotEqual(_key_id(comp), 0)

  with self.subTest('first_comp_non_zero_id'):
    first_transformed_comp, mutated = add_ids.transform(compiled_computation)
    self.assertTrue(mutated)
    self.assertIsInstance(first_transformed_comp,
                          building_blocks.CompiledComputation)
    _assert_has_non_zero_key(first_transformed_comp)

  with self.subTest('second_comp_same_id'):
    second_transformed_comp, mutated = add_ids.transform(compiled_computation)
    self.assertTrue(mutated)
    self.assertIsInstance(second_transformed_comp,
                          building_blocks.CompiledComputation)
    _assert_has_non_zero_key(second_transformed_comp)
    self.assertEqual(
        _key_id(first_transformed_comp), _key_id(second_transformed_comp))

  with self.subTest('restart_transformation_same_id'):
    # A new compiler pass must reproduce the ids of the earlier pass. Because
    # the compiler runs inside the `invoke` call, different computations must
    # also not share ids (checked in the next subtest).
    add_ids = compiled_computation_transformations.AddUniqueIDs()
    third_transformed_comp, mutated = add_ids.transform(compiled_computation)
    self.assertTrue(mutated)
    _assert_has_non_zero_key(third_transformed_comp)
    self.assertEqual(
        _key_id(first_transformed_comp), _key_id(third_transformed_comp))

  with self.subTest('different_computation_different_id'):
    different_compiled_computation = _create_compiled_computation(
        lambda x: x + tf.constant(1.0),
        computation_types.TensorType(tf.float32))
    different_transformed_comp, mutated = add_ids.transform(
        different_compiled_computation)
    self.assertTrue(mutated)
    _assert_has_non_zero_key(different_transformed_comp)
    self.assertNotEqual(
        _key_id(first_transformed_comp), _key_id(different_transformed_comp))
def test_returns_tf_computation_block_with_compiled_comp(self):
  """A block with an unused local whose result is a compiled graph compiles.

  The check itself is delegated to `assert_compiles_to_tensorflow`.
  """
  int_type = computation_types.TensorType(tf.int32)
  tf_identity = building_block_factory.create_compiled_identity(int_type)
  unused_int = building_block_factory.create_tensorflow_constant(int_type, 1)
  self.assert_compiles_to_tensorflow(
      building_blocks.Block([('x', unused_int)], tf_identity))
def test_transform_no_disallowed_ops(self):
  """With only `ShardedFilename` disallowed, an identity graph is not mutated."""
  tensor_type = computation_types.TensorType(tf.int32)
  identity = building_block_factory.create_compiled_identity(tensor_type)
  checker = compiled_computation_transformations.RaiseOnDisallowedOp(
      frozenset(['ShardedFilename']))
  _, mutated = checker.transform(identity)
  self.assertFalse(mutated)
def test_returns_string_for_compiled_computation(self):
  """String renderings of a compiled identity named `a`."""
  comp = building_block_factory.create_compiled_identity(tf.int32, 'a')
  self.assertEqual(comp.compact_representation(), 'comp#a')
  self.assertEqual(comp.formatted_representation(), 'comp#a')
  self.assertEqual(comp.structural_representation(), 'Compiled(a)')
def test_should_transform_compiled_computation(self):
  """`TensorFlowOptimizer` reports that it applies to a compiled identity."""
  tensor_type = computation_types.TensorType(tf.int32)
  identity = building_block_factory.create_compiled_identity(tensor_type)
  optimizer = compiled_computation_transformations.TensorFlowOptimizer(
      tf.compat.v1.ConfigProto())
  self.assertTrue(optimizer.should_transform(identity))
def test_transform_only_allowed_ops(self):
  """`VerifyAllowedOps` leaves an identity graph unmutated given this op set."""
  tensor_type = computation_types.TensorType(tf.int32)
  identity = building_block_factory.create_compiled_identity(tensor_type)
  permitted = frozenset(['Const', 'PartitionedCall', 'Identity', 'Placeholder'])
  verifier = compiled_computation_transformations.VerifyAllowedOps(permitted)
  _, mutated = verifier.transform(identity)
  self.assertFalse(mutated)
def test_transform_disallowed_ops(self):
  """Disallowing `Identity` makes transforming an identity graph raise."""
  tensor_type = computation_types.TensorType(tf.int32)
  identity = building_block_factory.create_compiled_identity(tensor_type)
  checker = compiled_computation_transformations.RaiseOnDisallowedOp(
      frozenset(['Identity']))
  expected_error = (
      tensorflow_computation_transformations
      .DisallowedOpInTensorFlowComputationError)
  with self.assertRaises(expected_error):
    checker.transform(identity)