def test_fails_conflicting_binding_in_parameter_and_result(self):
  """`T -> T` cannot bind `T` to int32 (parameter) and float32 (result)."""
  abstract_t = computation_types.AbstractType('T')
  generic_type = computation_types.FunctionType(abstract_t, abstract_t)
  concrete_type = computation_types.FunctionType(tf.int32, tf.float32)
  with self.assertRaises(type_analysis.UnassignableConcreteTypesError):
    type_analysis.check_concrete_instance_of(concrete_type, generic_type)
def test_succeeds_function_different_parameter_and_return_types(self):
  """`<U,T> -> T` is instantiated by `<int32,float32> -> float32`."""
  abstract_u = computation_types.AbstractType('U')
  abstract_t = computation_types.AbstractType('T')
  generic_param = computation_types.StructType([abstract_u, abstract_t])
  generic_type = computation_types.FunctionType(generic_param, abstract_t)
  concrete_param = computation_types.StructType([tf.int32, tf.float32])
  concrete_type = computation_types.FunctionType(concrete_param, tf.float32)
  type_analysis.check_concrete_instance_of(concrete_type, generic_type)
class IsStructureOfIntegersTest(parameterized.TestCase):
  """Tests for `type_analysis.is_structure_of_integers`."""

  @parameterized.named_parameters([
      ('empty_struct', computation_types.StructType([])),
      ('int', computation_types.TensorType(tf.int32)),
      ('ints', computation_types.StructType([tf.int32, tf.int32])),
      ('nested_struct',
       computation_types.StructType([
           computation_types.TensorType(tf.int32),
           computation_types.StructType([tf.int32, tf.int32]),
       ])),
      ('federated_int_at_clients',
       computation_types.FederatedType(tf.int32, placements.CLIENTS)),
  ])
  def test_returns_true(self, type_spec):
    """Int tensors, (nested) structs of them and federated ints qualify."""
    is_integer_structure = type_analysis.is_structure_of_integers(type_spec)
    self.assertTrue(is_integer_structure)

  @parameterized.named_parameters([
      ('bool', computation_types.TensorType(tf.bool)),
      ('float', computation_types.TensorType(tf.float32)),
      ('string', computation_types.TensorType(tf.string)),
      ('int_and_bool', computation_types.StructType([tf.int32, tf.bool])),
      ('nested_struct',
       computation_types.StructType([
           computation_types.TensorType(tf.int32),
           computation_types.StructType([tf.bool, tf.bool]),
       ])),
      ('sequence_of_ints', computation_types.SequenceType(tf.int32)),
      ('placement', computation_types.PlacementType()),
      ('function', computation_types.FunctionType(tf.int32, tf.int32)),
      ('abstract', computation_types.AbstractType('T')),
  ])
  def test_returns_false(self, type_spec):
    """Non-int dtypes and sequence/placement/function/abstract types fail."""
    is_integer_structure = type_analysis.is_structure_of_integers(type_spec)
    self.assertFalse(is_integer_structure)
class IsSumCompatibleTest(parameterized.TestCase):
  """Tests for `type_analysis.check_is_sum_compatible`."""

  @parameterized.named_parameters(
      ('tensor_type', computation_types.TensorType(tf.int32)),
      ('tuple_type_int', computation_types.StructType([tf.int32, tf.int32])),
      ('tuple_type_float',
       computation_types.StructType([tf.complex128, tf.float32, tf.float64])),
      ('federated_type',
       computation_types.FederatedType(tf.int32, placements.CLIENTS)),
  )
  def test_positive_examples(self, type_spec):
    # Passing means the check returns without raising.
    type_analysis.check_is_sum_compatible(type_spec)

  @parameterized.named_parameters(
      ('tensor_type_bool', computation_types.TensorType(tf.bool)),
      ('tensor_type_string', computation_types.TensorType(tf.string)),
      ('partially_defined_shape',
       computation_types.TensorType(tf.int32, shape=[None])),
      ('tuple_type', computation_types.StructType([tf.int32, tf.bool])),
      ('sequence_type', computation_types.SequenceType(tf.int32)),
      ('placement_type', computation_types.PlacementType()),
      ('function_type', computation_types.FunctionType(tf.int32, tf.int32)),
      ('abstract_type', computation_types.AbstractType('T')),
      ('ragged_tensor',
       computation_types.StructWithPythonType([], tf.RaggedTensor)),
      ('sparse_tensor',
       computation_types.StructWithPythonType([], tf.SparseTensor)),
  )
  def test_negative_examples(self, type_spec):
    self.assertRaises(type_analysis.SumIncompatibleError,
                      type_analysis.check_is_sum_compatible, type_spec)
class VisitPreorderTest(parameterized.TestCase):
  """Counts how many nodes `type_transformations.visit_preorder` visits."""

  # pyformat: disable
  @parameterized.named_parameters([
      ('abstract_type', computation_types.AbstractType('T'), 1),
      ('nested_function_type',
       computation_types.FunctionType(
           computation_types.FunctionType(
               computation_types.FunctionType(tf.int32, tf.int32),
               tf.int32),
           tf.int32),
       7),
      ('named_tuple_type',
       computation_types.StructType(
           [tf.int32, tf.bool, computation_types.SequenceType(tf.int32)]),
       5),
      ('placement_type', computation_types.PlacementType(), 1),
  ])
  # pyformat: enable
  def test_preorder_call_count(self, type_signature, expected_count):
    visits = 0

    def _count_hits(given_type, arg):
      nonlocal visits
      del given_type  # Unused.
      visits += 1
      return arg

    type_transformations.visit_preorder(type_signature, _count_hits, None)
    self.assertEqual(visits, expected_count)
def test_fails_conflicting_concrete_types_under_sequence(self):
  """Sequence of `<T,T>` cannot bind `T` to both int32 and float32."""
  abstract_elements = [computation_types.AbstractType('T')] * 2
  generic_type = self.func_with_param(
      computation_types.SequenceType(abstract_elements))
  concrete_elements = [tf.int32, tf.float32]
  concrete_type = self.func_with_param(
      computation_types.SequenceType(concrete_elements))
  with self.assertRaises(type_analysis.MismatchedConcreteTypesError):
    type_analysis.check_concrete_instance_of(concrete_type, generic_type)
def _federated_select(client_keys, max_key, server_val, select_fn, secure):
  """Internal helper for `federated_select` and `federated_secure_select`.

  Converts every argument with `value_impl.to_value`, validates the resulting
  type signatures, and wraps the computation produced by
  `building_block_factory.create_federated_select` in a `value_impl.Value`.

  Args:
    client_keys: The select keys; their type is validated by
      `_check_select_keys_type`.
    max_key: A value whose type must be assignable to `int32@SERVER`.
    server_val: A value that must be federated and placed at `SERVER`.
    select_fn: A function whose parameter must accept
      `<server_val member type, int32>`.
    secure: Whether the secure variant is being built; forwarded to the type
      checks and to `create_federated_select`.

  Returns:
    A `value_impl.Value` wrapping the result of `_bind_comp_as_reference`.

  Raises:
    Whatever `_check_select_keys_type` / `_select_parameter_mismatch` raise
    on a type mismatch (presumably `TypeError` -- confirm against their
    definitions).
  """
  client_keys = value_impl.to_value(client_keys, None)
  _check_select_keys_type(client_keys.type_signature, secure)

  max_key = value_impl.to_value(max_key, None)
  expected_max_key_type = computation_types.at_server(tf.int32)
  if not expected_max_key_type.is_assignable_from(max_key.type_signature):
    # NOTE(review): the expected type is `tf.int32` (signed) while the
    # description below says "unsigned"; confirm the intended wording.
    _select_parameter_mismatch(
        max_key.type_signature,
        'a 32-bit unsigned integer placed at server',
        'max_key',
        secure,
        expected_type=expected_max_key_type)

  server_val = value_impl.to_value(server_val, None)
  server_val = value_utils.ensure_federated_value(
      server_val, label='server_val')
  expected_server_val_type = computation_types.at_server(
      computation_types.AbstractType('T'))
  if (not server_val.type_signature.is_federated() or
      not server_val.type_signature.placement.is_server()):
    _select_parameter_mismatch(
        server_val.type_signature,
        'a value placed at server',
        'server_val',
        secure,
        expected_type=expected_server_val_type)

  # `select_fn` is applied to a `<server_val member, int32 key>` pair.
  select_fn_param_type = computation_types.to_type(
      [server_val.type_signature.member, tf.int32])
  select_fn = value_impl.to_value(
      select_fn, None, parameter_type_hint=select_fn_param_type)
  expected_select_fn_type = computation_types.FunctionType(
      select_fn_param_type, computation_types.AbstractType('U'))
  select_fn_type = select_fn.type_signature
  # A no-argument function type has a `None` parameter; report it as a
  # parameter mismatch rather than failing with `AttributeError` on
  # `None.is_assignable_from`.
  if (not select_fn_type.is_function() or select_fn_type.parameter is None or
      not select_fn_type.parameter.is_assignable_from(select_fn_param_type)):
    _select_parameter_mismatch(
        select_fn_type,
        'a function from state and key to result',
        'select_fn',
        secure,
        expected_type=expected_select_fn_type)

  comp = building_block_factory.create_federated_select(
      client_keys.comp, max_key.comp, server_val.comp, select_fn.comp, secure)
  comp = _bind_comp_as_reference(comp)
  return value_impl.Value(comp)
def test_succeeds_under_tuple(self):
  """`<T1,T1>` is instantiated by `<int32,int32>`."""
  abstract_pair = [computation_types.AbstractType('T1')] * 2
  generic_type = self.func_with_param(
      computation_types.StructType(abstract_pair))
  int_tensor = computation_types.TensorType
  concrete_type = self.func_with_param(
      computation_types.StructType([int_tensor(tf.int32),
                                    int_tensor(tf.int32)]))
  type_analysis.check_concrete_instance_of(concrete_type, generic_type)
def test_fails_under_tuple_conflicting_concrete_types(self):
  """`<T1,T1>` cannot bind `T1` to both int32 and float32."""
  abstract_pair = [computation_types.AbstractType('T1')] * 2
  generic_type = self.func_with_param(
      computation_types.StructType(abstract_pair))
  mixed_pair = [
      computation_types.TensorType(tf.int32),
      computation_types.TensorType(tf.float32),
  ]
  concrete_type = self.func_with_param(
      computation_types.StructType(mixed_pair))
  with self.assertRaises(type_analysis.MismatchedConcreteTypesError):
    type_analysis.check_concrete_instance_of(concrete_type, generic_type)
def test_abstract_federated_types_succeeds(self):
  """All-equal `<T1,T1>@CLIENTS` is instantiated by `<int32,int32>@CLIENTS`."""
  abstract_member = [computation_types.AbstractType('T1')] * 2
  generic_type = self.func_with_param(
      computation_types.FederatedType(
          abstract_member, placements.CLIENTS, all_equal=True))
  concrete_member = [tf.int32] * 2
  concrete_type = self.func_with_param(
      computation_types.FederatedType(
          concrete_member, placements.CLIENTS, all_equal=True))
  type_analysis.check_concrete_instance_of(concrete_type, generic_type)
def test_transforms_abstract_type(self):
  """`T` becomes float32 under the abstract transform; unchanged otherwise."""
  abstract_t = computation_types.AbstractType('T')
  float_tensor = computation_types.TensorType(tf.float32)
  converted, converted_mutated = type_transformations.transform_type_postorder(
      abstract_t, _convert_abstract_type_to_tensor)
  unchanged, unchanged_mutated = type_transformations.transform_type_postorder(
      abstract_t, _convert_placement_type_to_tensor)
  self.assertEqual(converted, float_tensor)
  self.assertEqual(unchanged, abstract_t)
  self.assertTrue(converted_mutated)
  self.assertFalse(unchanged_mutated)
def test_abstract_can_be_concretized_fails_on_different_placements(self):
  """`<T1,T1>@CLIENTS` does not match `<int32,int32>@SERVER`."""
  abstract_member = [computation_types.AbstractType('T1')] * 2
  generic_type = self.func_with_param(
      computation_types.FederatedType(
          abstract_member, placements.CLIENTS, all_equal=True))
  concrete_member = [tf.int32] * 2
  concrete_type = self.func_with_param(
      computation_types.FederatedType(
          concrete_member, placements.SERVER, all_equal=True))
  with self.assertRaises(type_analysis.MismatchedStructureError):
    type_analysis.check_concrete_instance_of(concrete_type, generic_type)
def test_abstract_parameters_contravariant(self):
  """`<A,(A->A)> -> A` accepts one-int32 structs differing only in names."""

  def _int_struct(name):
    return computation_types.StructType([(name, tf.int32)])

  unnamed = _int_struct(None)
  callback = computation_types.FunctionType(_int_struct('bar'), unnamed)
  concrete = computation_types.FunctionType(
      computation_types.StructType([unnamed, callback]), _int_struct('foo'))

  abstract = computation_types.AbstractType('A')
  generic_param = computation_types.StructType(
      [abstract, computation_types.FunctionType(abstract, abstract)])
  generic = computation_types.FunctionType(generic_param, abstract)

  type_analysis.check_concrete_instance_of(concrete, generic)
def test_raises_not_implemented_error_with_unimplemented_intrinsic(self):
  """`create_call` on an intrinsic without an implementation raises."""
  executor = create_test_executor()
  # The `IntrinsicDef` must exist so that the URI below can be looked up.
  intrinsic_def = intrinsic_defs.IntrinsicDef(
      'WHIMSY_INTRINSIC', 'whimsy_intrinsic',
      computation_types.AbstractType('T'))
  result_type = computation_types.TensorType(tf.int32)
  intrinsic_proto = pb.Computation(
      intrinsic=pb.Intrinsic(uri='whimsy_intrinsic'),
      type=type_serialization.serialize_type(result_type))
  del intrinsic_def
  intrinsic_value = self.run_sync(executor.create_value(intrinsic_proto))
  with self.assertRaises(NotImplementedError):
    self.run_sync(executor.create_call(intrinsic_value))
def test_executor_call_unsupported_intrinsic(self):
  """A composing federating executor raises on an unsupported intrinsic."""
  # The `IntrinsicDef` must exist so that the URI below can be looked up.
  intrinsic_def = intrinsic_defs.IntrinsicDef(
      'WHIMSY_INTRINSIC', 'whimsy_intrinsic',
      computation_types.AbstractType('T'))
  result_type = computation_types.TensorType(tf.int32)
  intrinsic_proto = pb.Computation(
      type=type_serialization.serialize_type(result_type),
      intrinsic=pb.Intrinsic(uri='whimsy_intrinsic'))
  del intrinsic_def

  strategy_factory = (
      federated_composing_strategy.FederatedComposingStrategy.factory(
          _create_bottom_stack(), [_create_worker_stack()]))
  executor = federating_executor.FederatingExecutor(strategy_factory,
                                                    _create_bottom_stack())

  intrinsic_value = asyncio.run(executor.create_value(intrinsic_proto))
  with self.assertRaises(NotImplementedError):
    asyncio.run(executor.create_call(intrinsic_value))
def test_fails_with_bad_types(self):
  """`tensorflow_wrapper` rejects non-TensorFlow parameter types.

  Federated, functional, placement, abstract, and struct-of-those parameter
  types each produce a `TypeError` naming the offending type.
  """
  function = computation_types.FunctionType(
      None, computation_types.TensorType(tf.int32))
  federated = computation_types.FederatedType(tf.int32, placements.CLIENTS)
  tuple_on_function = computation_types.StructType([federated, function])

  def foo(x):  # pylint: disable=unused-variable
    del x  # Unused.

  # Literal braces and parens are escaped so `re` matches them verbatim.
  with self.assertRaisesRegex(
      TypeError,
      r'you have attempted to create one with the type \{int32\}@CLIENTS'):
    computation_wrapper_instances.tensorflow_wrapper(foo, federated)

  with self.assertRaisesRegex(
      TypeError,
      r'you have attempted to create one with the type \( -> int32\)'):
    computation_wrapper_instances.tensorflow_wrapper(foo, function)

  with self.assertRaisesRegex(
      TypeError, r'you have attempted to create one with the type placement'):
    computation_wrapper_instances.tensorflow_wrapper(
        foo, computation_types.PlacementType())

  with self.assertRaisesRegex(
      TypeError, r'you have attempted to create one with the type T'):
    computation_wrapper_instances.tensorflow_wrapper(
        foo, computation_types.AbstractType('T'))

  # Both concatenated fragments are raw strings. A non-raw `'\)'` is an
  # invalid escape sequence that triggers a Deprecation/SyntaxWarning.
  with self.assertRaisesRegex(
      TypeError,
      r'you have attempted to create one with the type <\{int32\}@CLIENTS,\( '
      r'-> int32\)>'):
    computation_wrapper_instances.tensorflow_wrapper(foo, tuple_on_function)
def test_succeeds_abstract_type_under_sequence_type(self):
  """A sequence of `T` is instantiated by a sequence of int32."""
  abstract_sequence = computation_types.SequenceType(
      computation_types.AbstractType('T'))
  generic_type = self.func_with_param(abstract_sequence)
  int_sequence = computation_types.SequenceType(tf.int32)
  concrete_type = self.func_with_param(int_sequence)
  type_analysis.check_concrete_instance_of(concrete_type, generic_type)
class CheckWellFormedTest(parameterized.TestCase):
  """Type constructors accept well-formed nestings and reject others."""

  # pyformat: disable
  @parameterized.named_parameters([
      ('abstract_type', lambda: computation_types.AbstractType('T')),
      ('federated_type',
       lambda: computation_types.FederatedType(tf.int32, placements.CLIENTS)),
      ('function_type',
       lambda: computation_types.FunctionType(tf.int32, tf.int32)),
      ('named_tuple_type',
       lambda: computation_types.StructType([tf.int32] * 3)),
      ('placement_type', computation_types.PlacementType),
      ('sequence_type', lambda: computation_types.SequenceType(tf.int32)),
      ('tensor_type', lambda: computation_types.TensorType(tf.int32)),
  ])
  # pyformat: enable
  def test_does_not_raise_type_error(self, create_type_signature):
    try:
      create_type_signature()
    except TypeError:
      self.fail('Raised TypeError unexpectedly.')

  # pyformat: disable
  @parameterized.named_parameters([
      ('federated_function_type',
       lambda: computation_types.FederatedType(  # pylint: disable=g-long-lambda
           computation_types.FunctionType(tf.int32, tf.int32),
           placements.CLIENTS)),
      ('federated_federated_type',
       lambda: computation_types.FederatedType(  # pylint: disable=g-long-lambda
           computation_types.FederatedType(tf.int32, placements.CLIENTS),
           placements.CLIENTS)),
      ('sequence_sequence_type',
       lambda: computation_types.SequenceType(  # pylint: disable=g-long-lambda
           computation_types.SequenceType([tf.int32]))),
      ('sequence_federated_type',
       lambda: computation_types.SequenceType(  # pylint: disable=g-long-lambda
           computation_types.FederatedType(tf.int32, placements.CLIENTS))),
      ('tuple_federated_function_type',
       lambda: computation_types.StructType([  # pylint: disable=g-long-lambda
           computation_types.FederatedType(
               computation_types.FunctionType(tf.int32, tf.int32),
               placements.CLIENTS)
       ])),
      ('tuple_federated_federated_type',
       lambda: computation_types.StructType([  # pylint: disable=g-long-lambda
           computation_types.FederatedType(
               computation_types.FederatedType(tf.int32, placements.CLIENTS),
               placements.CLIENTS)
       ])),
      ('federated_tuple_function_type',
       lambda: computation_types.FederatedType(  # pylint: disable=g-long-lambda
           computation_types.StructType([
               computation_types.FunctionType(tf.int32, tf.int32)
           ]),
           placements.CLIENTS)),
  ])
  # pyformat: enable
  def test_raises_type_error(self, create_type_signature):
    self.assertRaises(TypeError, create_type_signature)
def test_succeeds_single_function_type(self):
  """`T -> T` is instantiated by `int32 -> int32`."""
  abstract_t = computation_types.AbstractType('T')
  generic_type = computation_types.FunctionType(abstract_t, abstract_t)
  concrete_type = computation_types.FunctionType(tf.int32, tf.int32)
  type_analysis.check_concrete_instance_of(concrete_type, generic_type)
def test_identity(self):
  """Two `AbstractType('T')` constructions return the very same object."""
  first, second = (computation_types.AbstractType('T'),
                   computation_types.AbstractType('T'))
  self.assertIs(first, second)
def test_returns_string_for_abstract_type(self):
  """Both compact and formatted representations of `T` are just 'T'."""
  type_spec = computation_types.AbstractType('T')
  for render in (type_spec.compact_representation,
                 type_spec.formatted_representation):
    self.assertEqual(render(), 'T')
def test_equality(self):
  """Abstract types compare equal exactly when their labels match."""
  same_a = computation_types.AbstractType('T')
  same_b = computation_types.AbstractType('T')
  different = computation_types.AbstractType('U')
  self.assertEqual(same_a, same_b)
  self.assertNotEqual(same_a, different)
def test_is_assignable_from(self):
  """Assignability between two abstract types raises `TypeError`."""
  target = computation_types.AbstractType('T1')
  source = computation_types.AbstractType('T2')
  self.assertRaises(TypeError, target.is_assignable_from, source)
def test_raises_with_abstract_type_as_first_arg(self):
  """An abstract type passed as the first argument is rejected."""
  abstract_first = computation_types.AbstractType('T1')
  tensor_second = computation_types.TensorType(tf.int32)
  self.assertRaises(type_analysis.NotConcreteTypeError,
                    type_analysis.check_concrete_instance_of, abstract_first,
                    tensor_second)
def test_construction(self):
  """Checks repr/str/label of `AbstractType` and rejects a non-str label."""
  abstract = computation_types.AbstractType('T1')
  self.assertEqual(repr(abstract), 'AbstractType(\'T1\')')
  self.assertEqual(str(abstract), 'T1')
  self.assertEqual(abstract.label, 'T1')
  with self.assertRaises(TypeError):
    computation_types.AbstractType(10)
def test_with_single_abstract_type_and_tuple_type(self):
  """A bare `T1` parameter is instantiated by a `<int32>` struct."""
  abstract_param = computation_types.AbstractType('T1')
  struct_param = computation_types.StructType([tf.int32])
  generic_type = self.func_with_param(abstract_param)
  concrete_type = self.func_with_param(struct_param)
  type_analysis.check_concrete_instance_of(concrete_type, generic_type)
def test_raises_with_abstract_type_in_first_and_second_argument(self):
  """Passing abstract types for both arguments is rejected."""
  generic_type = computation_types.AbstractType('T1')
  other_abstract = computation_types.AbstractType('T2')
  self.assertRaises(type_analysis.NotConcreteTypeError,
                    type_analysis.check_concrete_instance_of, other_abstract,
                    generic_type)
# # @federated_computation # def federated_aggregate(x, zero, accumulate, merge, report): # a = generic_partial_reduce(x, zero, accumulate, INTERMEDIATE_AGGREGATORS) # b = generic_reduce(a, zero, merge, SERVER) # c = generic_map(report, b) # return c # # Actual implementations might vary. # # Type signature: <{T}@CLIENTS,U,(<U,T>->U),(<U,U>->U),(U->R)> -> R@SERVER FEDERATED_AGGREGATE = IntrinsicDef( 'FEDERATED_AGGREGATE', 'federated_aggregate', computation_types.FunctionType(parameter=[ computation_types.at_clients(computation_types.AbstractType('T')), computation_types.AbstractType('U'), type_factory.reduction_op(computation_types.AbstractType('U'), computation_types.AbstractType('T')), type_factory.binary_op(computation_types.AbstractType('U')), computation_types.FunctionType(computation_types.AbstractType('U'), computation_types.AbstractType('R')) ], result=computation_types.at_server( computation_types.AbstractType('R'))), aggregation_kind=AggregationKind.DEFAULT) # Applies a given function to a value on the server. # # Type signature: <(T->U),T@SERVER> -> U@SERVER FEDERATED_APPLY = IntrinsicDef(
def test_with_single_abstract_type_and_tensor_type(self):
  """A bare `T1` is instantiated by an int32 tensor."""
  concrete_type = computation_types.TensorType(tf.int32)
  generic_type = computation_types.AbstractType('T1')
  type_analysis.check_concrete_instance_of(concrete_type, generic_type)
class CheckAllAbstractTypesAreBoundTest(parameterized.TestCase):
  """Tests for `type_analysis.check_all_abstract_types_are_bound`."""

  # pyformat: disable
  @parameterized.named_parameters([
      ('tensor_type', computation_types.TensorType(tf.int32)),
      ('function_type_with_no_arg',
       computation_types.FunctionType(None, tf.int32)),
      ('function_type_with_int_arg',
       computation_types.FunctionType(tf.int32, tf.int32)),
      ('function_type_with_abstract_arg',
       computation_types.FunctionType(computation_types.AbstractType('T'),
                                      computation_types.AbstractType('T'))),
      ('tuple_tuple_function_type_with_abstract_arg',
       computation_types.StructType([
           computation_types.StructType([
               computation_types.FunctionType(
                   computation_types.AbstractType('T'),
                   computation_types.AbstractType('T')),
           ])
       ])),
      ('function_type_with_unbound_function_arg',
       computation_types.FunctionType(
           computation_types.FunctionType(
               None, computation_types.AbstractType('T')),
           computation_types.AbstractType('T'))),
      ('function_type_with_sequence_arg',
       computation_types.FunctionType(
           computation_types.SequenceType(
               computation_types.AbstractType('T')),
           tf.int32)),
      ('function_type_with_two_abstract_args',
       computation_types.FunctionType(
           computation_types.StructType([
               computation_types.AbstractType('T'),
               computation_types.AbstractType('U'),
           ]),
           computation_types.StructType([
               computation_types.AbstractType('T'),
               computation_types.AbstractType('U'),
           ]))),
  ])
  # pyformat: enable
  def test_does_not_raise_type_error(self, type_spec):
    try:
      type_analysis.check_all_abstract_types_are_bound(type_spec)
    except TypeError:
      self.fail('Raised TypeError unexpectedly.')

  # pyformat: disable
  @parameterized.named_parameters([
      ('abstract_type', computation_types.AbstractType('T')),
      ('function_type_with_no_arg',
       computation_types.FunctionType(
           None, computation_types.AbstractType('T'))),
      ('function_type_with_int_arg',
       computation_types.FunctionType(
           tf.int32, computation_types.AbstractType('T'))),
      ('function_type_with_abstract_arg',
       computation_types.FunctionType(computation_types.AbstractType('T'),
                                      computation_types.AbstractType('U'))),
  ])
  # pyformat: enable
  def test_raises_type_error(self, type_spec):
    self.assertRaises(TypeError,
                      type_analysis.check_all_abstract_types_are_bound,
                      type_spec)