def test_prune_does_not_change_exeuction(self):
    """Check that a pruned proto produces the same results as the original.

    Both protos are executed with every integer in ``range(5)`` as the
    argument and the two outputs are compared for equality.
    """
    original = _create_proto_with_unnecessary_op()
    pruned = tensorflow_computation_transformations.prune_tensorflow_proto(
        original)
    for arg in range(5):
        from_original = test_utils.run_tensorflow(original, arg)
        from_pruned = test_utils.run_tensorflow(pruned, arg)
        self.assertEqual(from_original, from_pruned)
def test_executes_correctly_with_tuple_in_result(self):
    """Run a two-local block whose result is a pair of the same reference.

    The block chains two compiled int32 identities over the unbound
    reference ``var`` and returns ``second_call`` twice. The generated
    TensorFlow is run with 1 and with 0, and both result slots must equal
    the argument.
    """
    unbound_int = building_blocks.Reference('var', tf.int32)
    identity_one = building_block_factory.create_compiled_identity(
        computation_types.TensorType(tf.int32))
    first_call = building_blocks.Call(identity_one, unbound_int)
    first_call_ref = building_blocks.Reference(
        'call', first_call.type_signature)
    identity_two = building_block_factory.create_compiled_identity(
        computation_types.TensorType(tf.int32))
    second_call = building_blocks.Call(identity_two, first_call_ref)
    second_call_ref = building_blocks.Reference(
        'second_call', first_call.type_signature)
    block = building_blocks.Block(
        [('call', first_call), ('second_call', second_call)],
        building_blocks.Tuple([second_call_ref, second_call_ref]))

    tf_block, _ = transformations.create_tensorflow_representing_block(block)

    # Same order as before: argument 1 first, then argument 0.
    for arg in (1, 0):
        result = test_utils.run_tensorflow(tf_block.function.proto, arg)
        self.assertAllEqual(result, [arg, arg])
def test_replaces_lambda_to_called_graph_on_tuple_of_selections_from_arg_with_tf_of_same_type_with_names(
        self):
    """Parse a lambda calling an identity on selections from a named arg.

    The lambda takes a struct with fields ``a``, ``b`` and ``c`` and calls
    a compiled identity on the struct built from elements 0 and 2. The
    parsed result must be a `CompiledComputation` with an equal type
    signature, and running it must return the selected values.
    """
    identity = building_block_factory.create_compiled_identity(
        computation_types.StructType([tf.int32, tf.bool]))
    arg_ref = building_blocks.Reference(
        'x', [('a', tf.int32), ('b', tf.float32), ('c', tf.bool)])
    int_element = building_blocks.Selection(arg_ref, index=0)
    bool_element = building_blocks.Selection(arg_ref, index=2)
    int_and_bool = building_blocks.Struct([int_element, bool_element])
    identity_call = building_blocks.Call(identity, int_and_bool)
    wrapper = building_blocks.Lambda(
        'x', [('a', tf.int32), ('b', tf.float32), ('c', tf.bool)],
        identity_call)

    parsed, modified = parse_tff_to_tf(wrapper)

    self.assertIsInstance(parsed, building_blocks.CompiledComputation)
    self.assertTrue(modified)
    self.assertEqual(parsed.type_signature, wrapper.type_signature)
    run_arg = {
        'a': 9,
        'b': 10.0,
        'c': False,
    }
    result = test_utils.run_tensorflow(parsed.proto, run_arg)
    expected = structure.Struct([(None, 9), (None, False)])
    self.assertEqual(expected, result)
def test_returned_tensorflow_executes_correctly_with_no_unbound_refs(self):
    """Run a two-local block seeded by a constant instead of a reference.

    Two compiled int32 identities are chained over the TensorFlow constant
    1, and the block returns ``second_call`` twice. The generated
    TensorFlow is run without an argument and must yield ``[1, 1]``.
    """
    constant_one = building_block_factory.create_tensorflow_constant(
        computation_types.TensorType(tf.int32), 1)
    identity_one = building_block_factory.create_compiled_identity(
        computation_types.TensorType(tf.int32))
    first_call = building_blocks.Call(identity_one, constant_one)
    first_call_ref = building_blocks.Reference(
        'call', first_call.type_signature)
    identity_two = building_block_factory.create_compiled_identity(
        computation_types.TensorType(tf.int32))
    second_call = building_blocks.Call(identity_two, first_call_ref)
    second_call_ref = building_blocks.Reference(
        'second_call', first_call.type_signature)
    block = building_blocks.Block(
        [('call', first_call), ('second_call', second_call)],
        building_blocks.Tuple([second_call_ref, second_call_ref]))

    tf_block, _ = transformations.create_tensorflow_representing_block(block)

    self.assertAllEqual(
        test_utils.run_tensorflow(tf_block.function.proto), [1, 1])
def test_replaces_lambda_to_named_tuple_of_called_graphs_with_tf_of_same_type(
        self):
    """Parse a lambda returning a named struct of two identity calls.

    Element 0 of the argument goes through an int32 identity and element 1
    through a float32 identity; the results are named ``a`` and ``b``.
    """
    int_identity = building_block_factory.create_compiled_identity(
        computation_types.TensorType(tf.int32))
    float_identity = building_block_factory.create_compiled_identity(
        computation_types.TensorType(tf.float32))
    arg_ref = building_blocks.Reference('x', [tf.int32, tf.float32])
    int_element = building_blocks.Selection(arg_ref, index=0)
    float_element = building_blocks.Selection(arg_ref, index=1)
    int_call = building_blocks.Call(int_identity, int_element)
    float_call = building_blocks.Call(float_identity, float_element)
    named_calls = building_blocks.Struct([
        ('a', int_call),
        ('b', float_call),
    ])
    wrapper = building_blocks.Lambda(
        'x', [tf.int32, tf.float32], named_calls)

    parsed, modified = parse_tff_to_tf(wrapper)

    self.assertIsInstance(parsed, building_blocks.CompiledComputation)
    self.assertTrue(modified)
    # TODO(b/157172423): change to assertEqual when Py container is preserved.
    parsed.type_signature.check_equivalent_to(wrapper.type_signature)
    result = test_utils.run_tensorflow(parsed.proto, [13, 14.0])
    expected = structure.Struct([('a', 13), ('b', 14.0)])
    self.assertEqual(expected, result)
def test_replaces_lambda_to_called_composition_of_tf_blocks_with_tf_of_same_type_named_param(
        self):
    """Parse a lambda composing a select-first block with an add-one block.

    The argument is a struct with fields ``a`` (int32) and ``b``
    (float32). The first compiled block returns ``x[0]`` and the second
    adds one to it, so running with ``a=15`` must compare equal to 16.0.
    """
    select_first = _create_compiled_computation(
        lambda x: x[0],
        computation_types.StructType([('a', tf.int32), ('b', tf.float32)]))
    add_one = _create_compiled_computation(
        lambda x: x + 1, computation_types.TensorType(tf.int32))
    arg_ref = building_blocks.Reference(
        'x', [('a', tf.int32), ('b', tf.float32)])
    selected = building_blocks.Call(select_first, arg_ref)
    incremented = building_blocks.Call(add_one, selected)
    wrapper = building_blocks.Lambda(
        'x', [('a', tf.int32), ('b', tf.float32)], incremented)

    parsed, modified = parse_tff_to_tf(wrapper)

    self.assertIsInstance(parsed, building_blocks.CompiledComputation)
    self.assertTrue(modified)
    # TODO(b/157172423): change to assertEqual when Py container is preserved.
    parsed.type_signature.check_equivalent_to(wrapper.type_signature)
    result = test_utils.run_tensorflow(parsed.proto, {'a': 15, 'b': 16.0})
    self.assertEqual(16.0, result)
def test_replaces_lambda_to_called_graph_on_tuple_of_selections_from_arg_with_tf_of_same_type(
        self):
    """Parse a lambda calling an identity on selections from an unnamed arg.

    The lambda takes ``[int32, float32, bool]`` and calls a compiled
    identity on the struct built from elements 0 and 2.
    """
    identity = building_block_factory.create_compiled_identity(
        computation_types.StructType([tf.int32, tf.bool]))
    arg_ref = building_blocks.Reference(
        'x', [tf.int32, tf.float32, tf.bool])
    int_element = building_blocks.Selection(arg_ref, index=0)
    bool_element = building_blocks.Selection(arg_ref, index=2)
    int_and_bool = building_blocks.Struct([int_element, bool_element])
    identity_call = building_blocks.Call(identity, int_and_bool)
    wrapper = building_blocks.Lambda(
        'x', [tf.int32, tf.float32, tf.bool], identity_call)

    parsed, modified = parse_tff_to_tf(wrapper)

    self.assertIsInstance(parsed, building_blocks.CompiledComputation)
    self.assertTrue(modified)
    # TODO(b/157172423): change to assertEqual when Py container is preserved.
    parsed.type_signature.check_equivalent_to(wrapper.type_signature)
    result = test_utils.run_tensorflow(parsed.proto, [7, 8.0, True])
    expected = structure.Struct([(None, 7), (None, True)])
    self.assertEqual(expected, result)
def test_returns_computation(self, py_fn, type_signature, arg,
                             expected_result):
    """Check `create_computation_for_py_fn` output type and run result.

    Args:
        py_fn: The Python function handed to the factory.
        type_signature: The type signature handed to the factory.
        arg: The argument the resulting computation is run with.
        expected_result: The value the run is expected to produce.
    """
    proto, _ = tensorflow_computation_factory.create_computation_for_py_fn(
        py_fn, type_signature)

    self.assertIsInstance(proto, pb.Computation)
    self.assertEqual(test_utils.run_tensorflow(proto, arg), expected_result)
def test_returns_computation(self, type_signature, value):
    """Check `create_identity` output type, signature and run result.

    Args:
        type_signature: The type the identity computation is built for.
        value: The argument to run with; the result must equal it.
    """
    proto = tensorflow_computation_factory.create_identity(type_signature)

    self.assertIsInstance(proto, pb.Computation)
    deserialized = type_serialization.deserialize_type(proto.type)
    unary_type = type_factory.unary_op(type_signature)
    self.assertEqual(deserialized, unary_type)
    self.assertEqual(test_utils.run_tensorflow(proto, value), value)
def test_returns_coputation(self):
    """Check `create_empty_tuple` output type, signature and run result.

    The deserialized type must equal a no-parameter function returning an
    empty struct, and running the proto must yield an empty
    `AnonymousTuple`.
    """
    proto = tensorflow_computation_factory.create_empty_tuple()

    self.assertIsInstance(proto, pb.Computation)
    deserialized = type_serialization.deserialize_type(proto.type)
    no_arg_empty_result = computation_types.FunctionType(None, [])
    self.assertEqual(deserialized, no_arg_empty_result)
    run_output = test_utils.run_tensorflow(proto)
    self.assertEqual(run_output, anonymous_tuple.AnonymousTuple([]))
def test_returns_computation(self):
    """Check `create_empty_tuple` output type, signature and run result.

    The expected no-parameter, empty-result function type must be
    assignable from the deserialized type, and running the proto must
    yield an empty `Struct`.
    """
    proto, _ = tensorflow_computation_factory.create_empty_tuple()

    self.assertIsInstance(proto, pb.Computation)
    deserialized = type_serialization.deserialize_type(proto.type)
    no_arg_empty_result = computation_types.FunctionType(None, [])
    no_arg_empty_result.check_assignable_from(deserialized)
    run_output = test_utils.run_tensorflow(proto)
    self.assertEqual(run_output, structure.Struct([]))
def test_executes_correctly_after_resolving_multiple_variables(self):
    """Run a block that resolves two chained locals down to one reference.

    Two compiled int32 identities are chained over the unbound reference
    ``var`` and the block returns ``second_call``. The generated
    TensorFlow is run with 1 and with 0 and must echo the argument.
    """
    unbound_int = building_blocks.Reference('var', tf.int32)
    identity_one = building_block_factory.create_compiled_identity(tf.int32)
    first_call = building_blocks.Call(identity_one, unbound_int)
    first_call_ref = building_blocks.Reference(
        'call', first_call.type_signature)
    identity_two = building_block_factory.create_compiled_identity(tf.int32)
    second_call = building_blocks.Call(identity_two, first_call_ref)
    second_call_ref = building_blocks.Reference(
        'second_call', first_call.type_signature)
    block = building_blocks.Block(
        [('call', first_call), ('second_call', second_call)],
        second_call_ref)

    tf_block, _ = (
        compiler_transformations.create_tensorflow_representing_block(block))

    # Same order as before: argument 1 first, then argument 0.
    for arg in (1, 0):
        result = test_utils.run_tensorflow(tf_block.function.proto, arg)
        self.assertEqual(result, arg)
def test_generated_tensorflow_executes_correctly_sequence_parameter(self):
    """Call a lambda over an int32 sequence parameter on a sequence arg.

    The lambda body wraps its parameter in a one-element tuple, so running
    the generated TensorFlow on ``[0, 1, 2, 3, 4]`` must return a
    length-one result whose only element equals that list.
    """
    lambda_param = building_blocks.Reference(
        'x', computation_types.SequenceType(tf.int32))
    lambda_body = building_blocks.Tuple([lambda_param])
    concrete_arg = building_blocks.Reference(
        'y', computation_types.SequenceType(tf.int32))

    tf_block = (
        compiler_transformations
        .construct_tensorflow_calling_lambda_on_concrete_arg(
            lambda_param, lambda_body, concrete_arg))

    elements = range(5)
    result = test_utils.run_tensorflow(
        tf_block.function.proto, list(elements))
    self.assertLen(result, 1)
    self.assertAllEqual(result[0], list(elements))
def test_returns_computation(self, type_signature, count, value):
    """Check `create_replicate_input` output type, signature and result.

    Args:
        type_signature: The type of the input to replicate.
        count: How many copies the computation should return.
        value: The argument to run with; the result must be ``count``
            unnamed copies of it.
    """
    proto, _ = tensorflow_computation_factory.create_replicate_input(
        type_signature, count)

    self.assertIsInstance(proto, pb.Computation)
    deserialized = type_serialization.deserialize_type(proto.type)
    replicated_type = computation_types.FunctionType(
        type_signature, [type_signature] * count)
    replicated_type.check_assignable_from(deserialized)
    run_output = test_utils.run_tensorflow(proto, value)
    self.assertEqual(run_output, structure.Struct([(None, value)] * count))
def test_returns_computation_sequence(self):
    """Check `create_identity` for an int32 sequence type.

    The deserialized type must equal the unary-op type of the sequence,
    and running the proto on ``[10, 10, 10]`` must return an equal value.
    """
    sequence_type = computation_types.SequenceType(tf.int32)
    proto = tensorflow_computation_factory.create_identity(sequence_type)

    self.assertIsInstance(proto, pb.Computation)
    deserialized = type_serialization.deserialize_type(proto.type)
    unary_type = type_factory.unary_op(sequence_type)
    self.assertEqual(deserialized, unary_type)
    tens = [10] * 3
    self.assertEqual(test_utils.run_tensorflow(proto, tens), tens)
def test_returns_computation(self, type_signature, count, value):
    """Check `create_replicate_input` output type, signature and result.

    Args:
        type_signature: The type of the input to replicate.
        count: How many copies the computation should return.
        value: The argument to run with; the result must be ``count``
            unnamed copies of it.
    """
    proto = tensorflow_computation_factory.create_replicate_input(
        type_signature, count)

    self.assertIsInstance(proto, pb.Computation)
    deserialized = type_serialization.deserialize_type(proto.type)
    replicated_type = computation_types.FunctionType(
        type_signature, [type_signature] * count)
    self.assertEqual(deserialized, replicated_type)
    run_output = test_utils.run_tensorflow(proto, value)
    copies = anonymous_tuple.AnonymousTuple([(None, value)] * count)
    self.assertEqual(run_output, copies)
def test_returns_computation_with_tuple_unnamed(self):
    """Check `create_constant` for an unnamed tuple of three int32s.

    The deserialized type must equal a no-parameter function returning the
    tuple type, and the run result must contain the same elements as
    ``[10, 10, 10]``.
    """
    constant = 10
    tuple_type = computation_types.NamedTupleType([tf.int32] * 3)
    proto = tensorflow_computation_factory.create_constant(
        constant, tuple_type)

    self.assertIsInstance(proto, pb.Computation)
    deserialized = type_serialization.deserialize_type(proto.type)
    no_arg_type = computation_types.FunctionType(None, tuple_type)
    self.assertEqual(deserialized, no_arg_type)
    repeated = [constant] * 3
    # NOTE(review): `repeated` is also passed as the run argument although
    # the expected function type has no parameter - confirm this is intended.
    run_output = test_utils.run_tensorflow(proto, repeated)
    self.assertCountEqual(run_output, repeated)
def test_returns_computation_with_tensor_float(self):
    """Check `create_constant` for a float32 tensor of shape ``[3]``.

    The deserialized type must equal a no-parameter function returning the
    tensor type, and the run result must contain the same elements as
    ``[10.0, 10.0, 10.0]``.
    """
    constant = 10.0
    tensor_type = computation_types.TensorType(tf.float32, [3])
    proto = tensorflow_computation_factory.create_constant(
        constant, tensor_type)

    self.assertIsInstance(proto, pb.Computation)
    deserialized = type_serialization.deserialize_type(proto.type)
    no_arg_type = computation_types.FunctionType(None, tensor_type)
    self.assertEqual(deserialized, no_arg_type)
    repeated = [constant] * 3
    # NOTE(review): `repeated` is also passed as the run argument although
    # the expected function type has no parameter - confirm this is intended.
    run_output = test_utils.run_tensorflow(proto, repeated)
    self.assertCountEqual(run_output, repeated)
def test_returns_computation_tuple_named(self):
    """Check `create_identity` for a named tuple of int32 and float32.

    The deserialized type must equal the unary-op type of the signature,
    and running the proto on an `AnonymousTuple` with ``a=10`` and
    ``b=10.0`` must return an equal value.
    """
    named_type = [('a', tf.int32), ('b', tf.float32)]
    proto = tensorflow_computation_factory.create_identity(named_type)

    self.assertIsInstance(proto, pb.Computation)
    deserialized = type_serialization.deserialize_type(proto.type)
    unary_type = type_factory.unary_op(named_type)
    self.assertEqual(deserialized, unary_type)
    named_value = anonymous_tuple.AnonymousTuple([('a', 10), ('b', 10.0)])
    run_output = test_utils.run_tensorflow(proto, named_value)
    self.assertEqual(run_output, named_value)
def test_returns_computation(self, value, type_signature, expected_result):
    """Check `create_constant` output type, signature and run result.

    Args:
        value: The constant handed to the factory.
        type_signature: The type signature handed to the factory.
        expected_result: The expected run output. A list is compared with
            `assertCountEqual`; anything else with `assertEqual`.
    """
    proto, _ = tensorflow_computation_factory.create_constant(
        value, type_signature)

    self.assertIsInstance(proto, pb.Computation)
    deserialized = type_serialization.deserialize_type(proto.type)
    no_arg_type = computation_types.FunctionType(None, type_signature)
    no_arg_type.check_assignable_from(deserialized)
    run_output = test_utils.run_tensorflow(proto)
    compare = (
        self.assertCountEqual
        if isinstance(expected_result, list) else self.assertEqual)
    compare(run_output, expected_result)
def test_generated_tensorflow_executes_correctly_int_parameter(self):
    """Call a lambda over an int32 parameter on the constant 0.

    The lambda body is a tuple holding two references to ``x``, so the
    generated TensorFlow must return two elements that both equal 0.
    """
    lambda_param = building_blocks.Reference('x', tf.int32)
    lambda_body = building_blocks.Tuple([
        building_blocks.Reference('x', tf.int32),
        building_blocks.Reference('x', tf.int32),
    ])
    zero = building_block_factory.create_tensorflow_constant(tf.int32, 0)

    tf_block = (
        transformations.construct_tensorflow_calling_lambda_on_concrete_arg(
            lambda_param, lambda_body, zero))

    result = test_utils.run_tensorflow(tf_block.function.proto)
    self.assertLen(result, 2)
    for position in (0, 1):
        self.assertEqual(result[position], 0)
def test_generated_tensorflow_executes_correctly_tuple_parameter(self):
    """Call a lambda that swaps the two elements of its tuple parameter.

    The concrete argument is a TensorFlow constant built from
    ``[int32, float32]`` and the value 1. The result must have two
    elements, compared against ``1.`` and ``1`` respectively.
    """
    lambda_param = building_blocks.Reference('x', [tf.int32, tf.float32])
    swapped = building_blocks.Tuple([
        building_blocks.Selection(lambda_param, index=1),
        building_blocks.Selection(lambda_param, index=0),
    ])
    ones = building_block_factory.create_tensorflow_constant(
        [tf.int32, tf.float32], 1)

    tf_block = (
        compiler_transformations
        .construct_tensorflow_calling_lambda_on_concrete_arg(
            lambda_param, swapped, ones))

    result = test_utils.run_tensorflow(tf_block.function.proto)
    self.assertLen(result, 2)
    self.assertEqual(result[0], 1.)
    self.assertEqual(result[1], 1)
def test_returns_computation(self, operator, type_signature, operands,
                             expected_result):
    """Check `create_binary_operator_with_upcast` parameter type and result.

    Args:
        operator: The binary operator handed to the factory.
        type_signature: The type signature handed to the factory.
        operands: The argument the resulting computation is run with.
        expected_result: The value the run is expected to produce.
    """
    proto, _ = (
        tensorflow_computation_factory.create_binary_operator_with_upcast(
            type_signature, operator))

    self.assertIsInstance(proto, pb.Computation)
    deserialized = type_serialization.deserialize_type(proto.type)
    self.assertIsInstance(deserialized, computation_types.FunctionType)
    # Note: It is only useful to test the parameter type; the result type
    # depends on the `operator` used, not the implementation
    # `create_binary_operator_with_upcast`.
    parameter_type = computation_types.StructType(type_signature)
    self.assertEqual(deserialized.parameter, parameter_type)
    run_output = test_utils.run_tensorflow(proto, operands)
    self.assertEqual(run_output, expected_result)
def test_returns_computation(self, type_signature, shape, value,
                             expected_result):
    """Check `create_broadcast_scalar_to_shape` type, signature and result.

    Args:
        type_signature: The scalar type handed to the factory.
        shape: The target shape handed to the factory.
        value: The argument the resulting computation is run with.
        expected_result: The expected run output. A list is compared with
            `assertCountEqual`; anything else with `assertEqual`.
    """
    proto = tensorflow_computation_factory.create_broadcast_scalar_to_shape(
        type_signature, shape)

    self.assertIsInstance(proto, pb.Computation)
    deserialized = type_serialization.deserialize_type(proto.type)
    broadcast_type = computation_types.TensorType(type_signature, shape=shape)
    scalar_to_tensor = computation_types.FunctionType(
        type_signature, broadcast_type)
    self.assertEqual(deserialized, scalar_to_tensor)
    run_output = test_utils.run_tensorflow(proto, value)
    compare = (
        self.assertCountEqual
        if isinstance(expected_result, list) else self.assertEqual)
    compare(run_output, expected_result)
def test_replaces_lambda_to_called_graph_with_tf_of_same_type(self):
    """Parse a lambda that calls a compiled int32 identity on its argument.

    The parsed result must be a `CompiledComputation` with an equivalent
    type signature, and running it with 2 must return 2.
    """
    identity = building_block_factory.create_compiled_identity(
        computation_types.TensorType(tf.int32))
    arg_ref = building_blocks.Reference('x', tf.int32)
    identity_call = building_blocks.Call(identity, arg_ref)
    wrapper = building_blocks.Lambda('x', tf.int32, identity_call)

    parsed, modified = parse_tff_to_tf(wrapper)

    self.assertIsInstance(parsed, building_blocks.CompiledComputation)
    self.assertTrue(modified)
    # TODO(b/157172423): change to assertEqual when Py container is preserved.
    parsed.type_signature.check_equivalent_to(wrapper.type_signature)
    self.assertEqual(2, test_utils.run_tensorflow(parsed.proto, 2))
def test_replaces_lambda_to_called_graph_on_selection_from_arg_with_tf_of_same_type_with_names(
        self):
    """Parse a lambda calling an identity on element 0 of a named arg.

    The argument is a struct with fields ``a`` (int32) and ``b``
    (float32). The parsed result must be a `CompiledComputation` with an
    equal type signature, and running it with ``a=5`` must return 5.
    """
    identity = building_block_factory.create_compiled_identity(
        computation_types.TensorType(tf.int32))
    arg_ref = building_blocks.Reference(
        'x', [('a', tf.int32), ('b', tf.float32)])
    int_element = building_blocks.Selection(arg_ref, index=0)
    identity_call = building_blocks.Call(identity, int_element)
    wrapper = building_blocks.Lambda(
        'x', [('a', tf.int32), ('b', tf.float32)], identity_call)

    parsed, modified = parse_tff_to_tf(wrapper)

    self.assertIsInstance(parsed, building_blocks.CompiledComputation)
    self.assertTrue(modified)
    self.assertEqual(parsed.type_signature, wrapper.type_signature)
    run_output = test_utils.run_tensorflow(parsed.proto, {'a': 5, 'b': 6.0})
    self.assertEqual(5, run_output)
def test_replaces_lambda_to_called_tf_block_with_replicated_lambda_arg_with_tf_block_of_same_type(
        self):
    """Parse a lambda that feeds its argument twice into a summing block.

    The compiled block computes ``x[0] + x[1] + 1`` over a pair of int32s
    and the lambda passes ``(x, x)``, so running with 17 must return 35.
    """
    sum_plus_one = _create_compiled_computation(
        lambda x: x[0] + x[1] + 1,
        computation_types.StructType([tf.int32, tf.int32]))
    arg_ref = building_blocks.Reference('x', tf.int32)
    duplicated_arg = building_blocks.Struct((arg_ref, arg_ref))
    sum_call = building_blocks.Call(sum_plus_one, duplicated_arg)
    wrapper = building_blocks.Lambda('x', tf.int32, sum_call)

    parsed, modified = parse_tff_to_tf(wrapper)

    self.assertIsInstance(parsed, building_blocks.CompiledComputation)
    self.assertTrue(modified)
    # TODO(b/157172423): change to assertEqual when Py container is preserved.
    parsed.type_signature.check_equivalent_to(wrapper.type_signature)
    self.assertEqual(35, test_utils.run_tensorflow(parsed.proto, 17))
def test_returns_computation(self, operator, type_signature, operands,
                             expected_result):
    """Check `create_binary_operator_with_upcast` parameter type and result.

    Args:
        operator: The binary operator handed to the factory.
        type_signature: The type signature handed to the factory.
        operands: A structure mapped through `tf.constant` before the
            resulting computation is run with it.
        expected_result: The value the run is expected to produce.
    """
    # TODO(b/142795960): arguments in parameterized are called before test main.
    # `tf.constant` will error out on GPU and TPU without proper initialization.
    # A suggested workaround is to use numpy as argument and transform to TF
    # tensor inside the function.
    tensor_operands = tf.nest.map_structure(tf.constant, operands)
    proto, _ = (
        tensorflow_computation_factory.create_binary_operator_with_upcast(
            type_signature, operator))

    self.assertIsInstance(proto, pb.Computation)
    deserialized = type_serialization.deserialize_type(proto.type)
    self.assertIsInstance(deserialized, computation_types.FunctionType)
    # Note: It is only useful to test the parameter type; the result type
    # depends on the `operator` used, not the implementation
    # `create_binary_operator_with_upcast`.
    parameter_type = computation_types.StructType(type_signature)
    self.assertEqual(deserialized.parameter, parameter_type)
    run_output = test_utils.run_tensorflow(proto, tensor_operands)
    self.assertEqual(run_output, expected_result)
def test_replaces_lambda_to_selection_from_called_graph_with_tf_of_same_type(
        self):
    """Parse a lambda selecting element 1 from an identity call's result.

    The compiled identity is over ``[int32, float32]`` and is called on the
    lambda argument. Running the parsed computation with ``[0, 1.0]`` must
    return 1.0.
    """
    identity = building_block_factory.create_compiled_identity(
        computation_types.StructType([tf.int32, tf.float32]))
    arg_ref = building_blocks.Reference('x', [tf.int32, tf.float32])
    identity_call = building_blocks.Call(identity, arg_ref)
    float_element = building_blocks.Selection(identity_call, index=1)
    wrapper = building_blocks.Lambda(
        'x', [tf.int32, tf.float32], float_element)

    parsed, modified = parse_tff_to_tf(wrapper)

    self.assertIsInstance(parsed, building_blocks.CompiledComputation)
    self.assertTrue(modified)
    # TODO(b/157172423): change to assertEqual when Py container is preserved.
    parsed.type_signature.check_equivalent_to(wrapper.type_signature)
    self.assertEqual(1.0, test_utils.run_tensorflow(parsed.proto, [0, 1.0]))