def __init__(self, test, uri, value, type_spec):
  """Records the constructor arguments on the instance.

  Args:
    test: Stored as `self._test`.
    uri: Stored as `self._uri`.
    value: Stored as `self._value`.
    type_spec: Anything accepted by `computation_types.to_type`; stored in its
      normalized form as `self._type_spec`.
  """
  self._test, self._uri, self._value = test, uri, value
  self._type_spec = computation_types.to_type(type_spec)
def create(
    self,
    value_type: factory.ValueType) -> aggregation_process.AggregationProcess:
  """Creates a process that discretizes, aggregates, then undiscretizes values.

  Client values are discretized with `_discretize_struct` using a step size
  held in server state and broadcast to clients. The discretized values are
  aggregated by the inner aggregation factory's process. The aggregate is then
  mapped back at the server with `_undiscretize_struct`, using the original
  component dtypes.

  Args:
    value_type: A `TensorType`, or a `StructWithPythonType` containing only
      `TensorType`s. Every component dtype must be a float.

  Returns:
    An `aggregation_process.AggregationProcess`. Its state is an `OrderedDict`
    with `step_size` and `inner_agg_process` entries. Its measurements contain
    `deterministic_discretization` (the inner process's measurements) and, when
    a distortion aggregation factory is configured, `distortion`.

  Raises:
    TypeError: if `value_type` is not a tensor or a struct of tensors, or if
      any component dtype is not a float.
  """
  # Validate input args and value_type and parse out the TF dtypes.
  if value_type.is_tensor():
    tf_dtype = value_type.dtype
  elif (value_type.is_struct_with_python() and
        type_analysis.is_structure_of_tensors(value_type)):
    tf_dtype = type_conversions.structure_from_tensor_type_tree(
        lambda x: x.dtype, value_type)
  else:
    raise TypeError('Expected `value_type` to be `TensorType` or '
                    '`StructWithPythonType` containing only `TensorType`. '
                    f'Found type: {repr(value_type)}')

  # Check that all values are floats.
  if not type_analysis.is_structure_of_floats(value_type):
    raise TypeError('Component dtypes of `value_type` must all be floats. '
                    f'Found {repr(value_type)}.')

  # The distortion process aggregates one float32 scalar per client; it is only
  # built (and later used) when a distortion factory was configured.
  if self._distortion_aggregation_factory is not None:
    distortion_aggregation_process = self._distortion_aggregation_factory.create(
        computation_types.to_type(tf.float32))

  @tensorflow_computation.tf_computation(value_type, tf.float32)
  def discretize_fn(value, step_size):
    return _discretize_struct(value, step_size)

  # Takes the discretized type produced above; `tf_dtype` restores the
  # original component dtypes.
  @tensorflow_computation.tf_computation(discretize_fn.type_signature.result,
                                         tf.float32)
  def undiscretize_fn(value, step_size):
    return _undiscretize_struct(value, step_size, tf_dtype)

  # Mean squared error between `value` and its discretize->undiscretize round
  # trip, taken over every element of every component, in float32.
  @tensorflow_computation.tf_computation(value_type, tf.float32)
  def distortion_measurement_fn(value, step_size):
    reconstructed_value = undiscretize_fn(
        discretize_fn(value, step_size), step_size)
    err = tf.nest.map_structure(tf.subtract, reconstructed_value, value)
    squared_err = tf.nest.map_structure(tf.square, err)
    flat_squared_errs = [
        tf.cast(tf.reshape(t, [-1]), tf.float32)
        for t in tf.nest.flatten(squared_err)
    ]
    all_squared_errs = tf.concat(flat_squared_errs, axis=0)
    mean_squared_err = tf.reduce_mean(all_squared_errs)
    return mean_squared_err

  # The inner process aggregates the discretized (not the original) type.
  inner_agg_process = self._inner_agg_factory.create(
      discretize_fn.type_signature.result)

  @federated_computation.federated_computation()
  def init_fn():
    state = collections.OrderedDict(
        step_size=intrinsics.federated_value(self._step_size,
                                             placements.SERVER),
        inner_agg_process=inner_agg_process.initialize())
    return intrinsics.federated_zip(state)

  @federated_computation.federated_computation(
      init_fn.type_signature.result, computation_types.at_clients(value_type))
  def next_fn(state, value):
    server_step_size = state['step_size']
    client_step_size = intrinsics.federated_broadcast(server_step_size)

    discretized_value = intrinsics.federated_map(discretize_fn,
                                                 (value, client_step_size))

    inner_state = state['inner_agg_process']
    inner_agg_output = inner_agg_process.next(inner_state, discretized_value)

    # Undiscretization happens once, at the server, on the aggregate.
    undiscretized_agg_value = intrinsics.federated_map(
        undiscretize_fn, (inner_agg_output.result, server_step_size))

    # The step size is carried through unchanged.
    new_state = collections.OrderedDict(
        step_size=server_step_size, inner_agg_process=inner_agg_output.state)
    measurements = collections.OrderedDict(
        deterministic_discretization=inner_agg_output.measurements)

    if self._distortion_aggregation_factory is not None:
      distortions = intrinsics.federated_map(distortion_measurement_fn,
                                             (value, client_step_size))
      # A freshly initialized state is passed every round and only `.result`
      # is kept, so this process's state is not carried across rounds.
      aggregate_distortion = distortion_aggregation_process.next(
          distortion_aggregation_process.initialize(), distortions).result
      measurements['distortion'] = aggregate_distortion

    return measured_process.MeasuredProcessOutput(
        state=intrinsics.federated_zip(new_state),
        result=undiscretized_agg_value,
        measurements=intrinsics.federated_zip(measurements))

  return aggregation_process.AggregationProcess(init_fn, next_fn)
def serialize_jax_computation(traced_fn, arg_fn, parameter_type,
                              context_stack):
  """Traces a JAX-containing Python function and serializes it as a TFF computation.

  Args:
    traced_fn: Python callable with JAX code. It is traced by JAX and lowered
      to XLA, which becomes the body of the resulting TFF computation.
    arg_fn: Callable mapping a TFF argument to an `(args, kwargs)` pair with
      which `traced_fn` is invoked (for example, one built by
      `function_utils.create_argument_unpacking_fn`).
    parameter_type: The TFF type of the computation parameter (anything
      accepted by `computation_types.to_type`), or `None` when the function
      takes no parameters.
    context_stack: The context stack installed while tracing.

  Returns:
    A `pb.Computation` containing the serialized XLA computation.

  Raises:
    TypeError: if the arguments are of the wrong types.
  """
  for candidate in (traced_fn, arg_fn):
    py_typecheck.check_callable(candidate)
  py_typecheck.check_type(context_stack, context_stack_base.ContextStack)

  if parameter_type is None:
    packed_arg = None
  else:
    parameter_type = computation_types.to_type(parameter_type)
    packed_arg = _tff_type_to_xla_serializer_arg(parameter_type)

  args, kwargs = arg_fn(packed_arg)

  # The fake parameters are passed through args/kwargs, but XLA may order them
  # differently in the generated code. Flattening with JAX's own tree utility
  # reproduces the ordering the serializer uses, so the parameter binding can
  # record it. Results need no such treatment: multiple results always come
  # back as a tuple.
  flat_leaves, _ = jax.tree_util.tree_flatten((args, kwargs))
  tensor_indexes = list(
      np.argsort([leaf.tensor_index for leaf in flat_leaves]))

  with context_stack.install(jax_computation_context.JaxComputationContext()):
    compiled_xla, returned_shape = jax.xla_computation(
        traced_fn, tuple_args=True, return_shape=True)(*args, **kwargs)

  if isinstance(returned_shape, jax.ShapeDtypeStruct):
    result_type = _jax_shape_dtype_struct_to_tff_tensor(returned_shape)
  else:
    result_structure = structure.from_container(returned_shape, recursive=True)
    result_type = computation_types.to_type(
        structure.map_structure(_jax_shape_dtype_struct_to_tff_tensor,
                                result_structure))

  return xla_serialization.create_xla_tff_computation(
      compiled_xla, tensor_indexes,
      computation_types.FunctionType(parameter_type, result_type))
def save(model: model_lib.Model, path: str, input_type=None) -> None:
  """Serializes `model` as a TensorFlow SavedModel to `path`.

  The resulting SavedModel will contain the default serving signature, which
  can be used with the TFLite converter to create a TFLite flatbuffer for
  inference.

  NOTE: The model returned by `tff.learning.models.load` will _not_ be the same
  Python type as the saved model. If the model serialized using this method is
  a subclass of `tff.learning.Model`, that subclass is _not_ returned. All
  method behavior is retained, but the Python type does not cross
  serialization boundaries. The return type of `metric_finalizers` will be an
  OrderedDict of str to `tff.tf_computation` (annotated TFF computations) which
  could be different from that of the model before serialization.

  Args:
    model: The `tff.learning.Model` to save.
    path: The `str` directory path to serialize the model to.
    input_type: An optional structure of `tf.TensorSpec`s representing the
      expected input of `model.predict_on_batch`, to override reading from
      `model.input_spec`. Typically this will be similar to `model.input_spec`,
      with any example labels removed. If None, default to
      `model.input_spec['x']` if the input_spec is a mapping, otherwise default
      to `model.input_spec[0]`.

  Raises:
    TypeError: if `model` is not a `tff.learning.Model` or `path` is not a
      `str`.
    ValueError: if `path` is empty.
  """
  py_typecheck.check_type(model, model_lib.Model)
  py_typecheck.check_type(path, str)
  if not path:
    raise ValueError('`path` must be a non-empty string, cannot serialize '
                     'models without an output path.')
  if isinstance(model, _LoadedSavedModel):
    # If we're saving a previously loaded model, we can simply use the module
    # already internal to the Model.
    _save_tensorflow_module(model._loaded_module, path)  # pylint: disable=protected-access
    return

  def serialize_proto_as_variable(proto):
    # Wraps a proto's deterministic serialization in a non-trainable string
    # `tf.Variable` attached to the module.
    return tf.Variable(
        proto.SerializeToString(deterministic=True), trainable=False)

  m = tf.Module()
  # We prefixed with `tff_` because `trainable_variables` is an attribute
  # reserved by `tf.Module`.
  m.tff_trainable_variables = model.trainable_variables
  m.tff_non_trainable_variables = model.non_trainable_variables
  m.tff_local_variables = model.local_variables

  # Serialize forward_pass. We must get two concrete versions of the
  # function, as the `training` argument is a Python value that changes the
  # graph computation. We serialize the output type so that we can repack the
  # flattened values after loading the saved model.
  forward_pass_training = _make_concrete_flat_output_fn(
      functools.partial(model.forward_pass, training=True), model.input_spec)
  m.flat_forward_pass_training = forward_pass_training[0]
  m.forward_pass_training_type_spec = serialize_proto_as_variable(
      forward_pass_training[1])

  forward_pass_inference = _make_concrete_flat_output_fn(
      functools.partial(model.forward_pass, training=False), model.input_spec)
  m.flat_forward_pass_inference = forward_pass_inference[0]
  m.forward_pass_inference_type_spec = serialize_proto_as_variable(
      forward_pass_inference[1])

  # Get model prediction input type. If `None`, default to assuming the 'x' key
  # or first element of the model input spec is the input.
  if input_type is None:
    if isinstance(model.input_spec, collections.abc.Mapping):
      input_type = model.input_spec['x']
    else:
      input_type = model.input_spec[0]

  # Serialize predict_on_batch. We must get two concrete versions of the
  # function, as the `training` argument is a Python value that changes the
  # graph computation.
  predict_on_batch_training = _make_concrete_flat_output_fn(
      functools.partial(model.predict_on_batch, training=True), input_type)
  m.predict_on_batch_training = predict_on_batch_training[0]
  m.predict_on_batch_training_type_spec = serialize_proto_as_variable(
      predict_on_batch_training[1])

  predict_on_batch_inference = _make_concrete_flat_output_fn(
      functools.partial(model.predict_on_batch, training=False), input_type)
  m.predict_on_batch_inference = predict_on_batch_inference[0]
  m.predict_on_batch_inference_type_spec = serialize_proto_as_variable(
      predict_on_batch_inference[1])

  # Serialize the report_local_unfinalized_metrics tf.function.
  m.report_local_unfinalized_metrics = (
      model.report_local_unfinalized_metrics.get_concrete_function())

  # Serialize the metric_finalizers as `tf.Variable`s.
  m.serialized_metric_finalizers = collections.OrderedDict()

  def serialize_metric_finalizer(finalizer, metric_type):
    finalizer_computation = tensorflow_computation.tf_computation(
        finalizer, metric_type)
    return serialize_proto_as_variable(
        computation_serialization.serialize_computation(finalizer_computation))

  metric_finalizers = model.metric_finalizers()
  if metric_finalizers:
    # Run the metrics report once instead of once per metric; only the types
    # of its outputs are needed to build each finalizer's computation.
    unfinalized_metrics = model.report_local_unfinalized_metrics()
    for metric_name, finalizer in metric_finalizers.items():
      metric_type = type_conversions.type_from_tensors(
          unfinalized_metrics[metric_name])
      m.serialized_metric_finalizers[metric_name] = serialize_metric_finalizer(
          finalizer, metric_type)

  # Serialize the TFF values as string variables that contain the serialized
  # protos from the computation or the type.
  m.serialized_input_spec = serialize_proto_as_variable(
      type_serialization.serialize_type(
          computation_types.to_type(model.input_spec)))

  # Serialize the reset_metrics tf.function. Models that do not implement it
  # are saved without one.
  try:
    m.reset_metrics = model.reset_metrics.get_concrete_function()
  except NotImplementedError:
    m.reset_metrics = None

  _save_tensorflow_module(m, path)
def create_whimsy_intrinsic_def_federated_secure_sum_bitwidth():
  """Returns `FEDERATED_SECURE_SUM_BITWIDTH` with a whimsy type signature.

  The signature takes an int32 placed at CLIENTS and an unplaced int32, and
  returns an int32 placed at SERVER.
  """
  clients_int_type = computation_types.at_clients(tf.int32)
  signature = computation_types.FunctionType(
      [clients_int_type, tf.int32], computation_types.at_server(tf.int32))
  return intrinsic_defs.FEDERATED_SECURE_SUM_BITWIDTH, signature


# Whimsy component types used to assemble the `federated_secure_select` type.
_WHIMSY_SELECT_CLIENT_KEYS_TYPE = computation_types.at_clients(
    computation_types.TensorType(tf.int32, [3]))
_WHIMSY_SELECT_MAX_KEY_TYPE = computation_types.at_server(tf.int32)
_WHIMSY_SELECT_SERVER_STATE_TYPE = computation_types.at_server(tf.string)
_WHIMSY_SELECTED_TYPE = computation_types.to_type((tf.string, tf.int32))
_WHIMSY_SELECT_SELECT_FN_TYPE = computation_types.FunctionType(
    (tf.string, tf.int32), _WHIMSY_SELECTED_TYPE)
_WHIMSY_SELECT_RESULT_TYPE = computation_types.at_clients(
    computation_types.SequenceType(_WHIMSY_SELECTED_TYPE))
_WHIMSY_SELECT_TYPE = computation_types.FunctionType([
    _WHIMSY_SELECT_CLIENT_KEYS_TYPE,
    _WHIMSY_SELECT_MAX_KEY_TYPE,
    _WHIMSY_SELECT_SERVER_STATE_TYPE,
    _WHIMSY_SELECT_SELECT_FN_TYPE,
], _WHIMSY_SELECT_RESULT_TYPE)
_WHIMSY_SELECT_NUM_CLIENTS = 3


def create_whimsy_intrinsic_def_federated_secure_select():
  """Returns `FEDERATED_SECURE_SELECT` paired with `_WHIMSY_SELECT_TYPE`."""
  intrinsic = intrinsic_defs.FEDERATED_SECURE_SELECT
  return intrinsic, _WHIMSY_SELECT_TYPE
def test_structure_of_tensors(self):
  """Merges sample reservoirs built for a nested struct-of-tensors type."""
  example_type = computation_types.to_type(
      collections.OrderedDict(
          a=TensorType(tf.int32, [3]),
          b=[TensorType(tf.float32), TensorType(tf.bool)]))
  merge_computation = sampling._build_merge_samples_computation(
      example_type, sample_size=5)
  reservoir_type = sampling._build_reservoir_type(example_type)
  expected_type = FunctionType(
      parameter=collections.OrderedDict(a=reservoir_type, b=reservoir_type),
      result=reservoir_type)
  self.assert_types_identical(merge_computation.type_signature, expected_type)

  reservoir_a = sampling._build_initial_sample_reservoir(
      example_type, seed=TEST_SEED)
  reservoir_a['random_values'] = [1, 3, 5]
  reservoir_a['samples'] = collections.OrderedDict(
      a=[[0, 1, 2], [1, 2, 3], [2, 3, 4]],
      b=[[0.0, 1.0, 2.0], [True, False, True]])
  with self.subTest('downsample'):
    # Seven candidates for five slots: the two smallest random values (1 from
    # `a`, 2 from `b`) are dropped.
    reservoir_b = sampling._build_initial_sample_reservoir(
        example_type, seed=TEST_SEED + 1)
    reservoir_b['random_values'] = [2, 4, 6, 8]
    reservoir_b['samples'] = collections.OrderedDict(
        a=[[0, -1, -2], [-1, -2, -3], [-2, -3, -4], [-3, -4, -5]],
        b=[[-1., -2., -3., -4.], [True, False, False, True]])
    merged_reservoir = merge_computation(reservoir_a, reservoir_b)
    self.assertAllEqual(
        merged_reservoir,
        collections.OrderedDict(
            # Arbitrarily take seeds from `a`, discarded later.
            random_seed=tf.convert_to_tensor((TEST_SEED, TEST_SEED)),
            random_values=[3, 5, 4, 6, 8],
            samples=collections.OrderedDict(
                a=[[1, 2, 3], [2, 3, 4], [-1, -2, -3], [-2, -3, -4],
                   [-3, -4, -5]],
                b=[[1., 2., -2., -3., -4.],
                   [False, True, False, False, True]])))
  with self.subTest('keep_all'):
    reservoir_b = sampling._build_initial_sample_reservoir(
        example_type, seed=TEST_SEED + 1)
    reservoir_b['random_values'] = [2]
    reservoir_b['samples'] = collections.OrderedDict(
        a=[[0, -1, -2]], b=[[-1.0], [True]])
    # Only four samples exist in total, fewer than the sample size of 5, so
    # every sample from both reservoirs is kept (those from `a` first).
    merged_reservoir = merge_computation(reservoir_a, reservoir_b)
    self.assertAllEqual(
        merged_reservoir,
        collections.OrderedDict(
            # Arbitrarily take seeds from `a`, discarded later.
            random_seed=tf.convert_to_tensor((TEST_SEED, TEST_SEED)),
            random_values=[1, 3, 5, 2],
            samples=collections.OrderedDict(
                a=[[0, 1, 2], [1, 2, 3], [2, 3, 4], [-0, -1, -2]],
                b=[[0., 1., 2., -1.], [True, False, True, True]])))
  with self.subTest('tie_breakers'):
    # On ties, values from `a` are taken first: its single sample with random
    # value 5 is kept, and the remaining four slots are filled from `b`.
    reservoir_b = sampling._build_initial_sample_reservoir(
        example_type, seed=TEST_SEED)
    reservoir_b['random_values'] = [5, 5, 5, 5, 5]  # all tied with `a`
    reservoir_b['samples'] = collections.OrderedDict(
        a=[[-1, -1, -1]] * 5, b=[[-1] * 5, [False] * 5])
    merged_reservoir = merge_computation(reservoir_a, reservoir_b)
    self.assertAllEqual(
        merged_reservoir,
        collections.OrderedDict(
            random_seed=tf.convert_to_tensor((TEST_SEED, TEST_SEED)),
            random_values=[5, 5, 5, 5, 5],
            samples=collections.OrderedDict(
                a=[[2, 3, 4]] + [[-1, -1, -1]] * 4,
                b=[[2] + [-1] * 4, [True] + [False] * 4])))
def to_value(
    arg: Any,
    type_spec=None,
    parameter_type_hint=None,
) -> Value:
  """Converts the argument into an instance of the abstract class `tff.Value`.

  Instances of `tff.Value` represent TFF values that appear internally in
  federated computations. This helper function can be used to wrap a variety of
  Python objects as `tff.Value` instances to allow them to be passed as
  arguments, used as functions, or otherwise manipulated within bodies of
  federated computations.

  At the moment, the supported types include:

  * Simple constants of `str`, `int`, `float`, and `bool` types, mapped to
    values of a TFF tensor type.

  * Numpy arrays (`np.ndarray` objects), also mapped to TFF tensors.

  * Dictionaries (`collections.OrderedDict` and unordered `dict`), `list`s,
    `tuple`s, `namedtuple`s, `attrs`-decorated classes, and `Struct`s, all of
    which are mapped to TFF tuple type. Items of an unordered `dict` are
    sorted by key first.

  * Computations (constructed with either the `tff.tf_computation` or with the
    `tff.federated_computation` decorator), typically mapped to TFF functions.

  * Placement literals (`tff.CLIENTS`, `tff.SERVER`), mapped to values of the
    TFF placement type.

  This function is also invoked when attempting to execute a TFF computation.
  All arguments supplied in the invocation are converted into TFF values prior
  to execution. The types of Python objects that can be passed as arguments to
  computations thus match the types listed here.

  Args:
    arg: An instance of one of the Python types that are convertible to TFF
      values (instances of `tff.Value`).
    type_spec: An optional type specifier (defaults to `None`) that allows for
      disambiguating the target type (e.g., when two TFF types can be mapped to
      the same Python representations). If not specified, TFF tries to
      determine the type of the TFF value automatically. If it is a sequence
      type, `arg` is wrapped as a sequence of that element type.
    parameter_type_hint: An optional `tff.Type` or value convertible to it by
      `tff.to_type()` which specifies an argument type to use in the case that
      `arg` is a `function_utils.PolymorphicFunction`.

  Returns:
    An instance of `tff.Value` as described above.

  Raises:
    TypeError: if `arg` is of an unsupported type, or of a type that does not
      match `type_spec`; if `arg` is a polymorphic function and no
      `parameter_type_hint` is given; or, with an explicit message, if
      TensorFlow constructs are encountered, as TensorFlow code should be
      sealed away from TFF federated context.
  """
  if type_spec is not None:
    type_spec = computation_types.to_type(type_spec)
  if isinstance(arg, Value):
    result = arg
  elif isinstance(arg, building_blocks.ComputationBuildingBlock):
    result = Value(arg)
  elif isinstance(arg, placements.PlacementLiteral):
    result = Value(building_blocks.Placement(arg))
  elif isinstance(
      arg, (computation_base.Computation, function_utils.PolymorphicFunction)):
    if isinstance(arg, function_utils.PolymorphicFunction):
      if parameter_type_hint is None:
        raise TypeError(
            'Polymorphic computations cannot be converted to `tff.Value`s '
            'without a type hint. Consider explicitly specifying the '
            'argument types of a computation before passing it to a '
            'function that requires a `tff.Value` (such as a TFF intrinsic '
            'like `federated_map`). If you are a TFF developer and think '
            'this should be supported, consider providing `parameter_type_hint` '
            'as an argument to the encompassing `to_value` conversion.'
        )
      parameter_type_hint = computation_types.to_type(parameter_type_hint)
      # Pick the concrete computation for the hinted argument type.
      arg = arg.fn_for_argument_type(parameter_type_hint)
    py_typecheck.check_type(arg, computation_base.Computation)
    result = Value(arg.to_compiled_building_block())
  elif type_spec is not None and type_spec.is_sequence():
    result = _wrap_sequence_as_value(arg, type_spec.element)
  elif isinstance(arg, structure.Struct):
    items = structure.iter_elements(arg)
    result = _dictlike_items_to_value(items, type_spec, None)
  elif py_typecheck.is_named_tuple(arg):
    items = arg._asdict().items()
    result = _dictlike_items_to_value(items, type_spec, type(arg))
  elif py_typecheck.is_attrs(arg):
    items = attr.asdict(
        arg, dict_factory=collections.OrderedDict, recurse=False).items()
    result = _dictlike_items_to_value(items, type_spec, type(arg))
  elif isinstance(arg, dict):
    if isinstance(arg, collections.OrderedDict):
      items = arg.items()
    else:
      # Unordered dicts get a deterministic element order.
      items = sorted(arg.items())
    result = _dictlike_items_to_value(items, type_spec, type(arg))
  elif isinstance(arg, (tuple, list)):
    # Unnamed elements.
    items = zip(itertools.repeat(None), arg)
    result = _dictlike_items_to_value(items, type_spec, type(arg))
  elif isinstance(arg, tensorflow_utils.TENSOR_REPRESENTATION_TYPES):
    result = _wrap_constant_as_value(arg)
  elif isinstance(arg, (tf.Tensor, tf.Variable)):
    raise TypeError(
        'TensorFlow construct {} has been encountered in a federated '
        'context. TFF does not support mixing TF and federated orchestration '
        'code. Please wrap any TensorFlow constructs with '
        '`tff.tf_computation`.'.format(arg))
  else:
    raise TypeError(
        'Unable to interpret an argument of type {} as a `tff.Value`.'.format(
            py_typecheck.type_string(type(arg))))
  py_typecheck.check_type(result, Value)
  if (type_spec is not None and
      not type_spec.is_assignable_from(result.type_signature)):
    raise TypeError(
        'The supplied argument maps to TFF type {}, which is incompatible with '
        'the requested type {}.'.format(result.type_signature, type_spec))
  return result
def test_increasing_zero_clip_sum(self):
  """Zeroing and clipping with norms that grow by non-integer steps.

  Each round the zeroing norm grows by 0.75 and the clipping norm by 0.25.
  """

  def make_growing_next_fn(increment):
    # Builds a next_fn that ignores client values and adds `increment` to the
    # server state.
    @computations.federated_computation(_float_at_server, _float_at_clients)
    def next_fn(state, value):
      del value
      return intrinsics.federated_map(
          computations.tf_computation(lambda x: x + increment, tf.float32),
          state)

    return next_fn

  zeroing_norm_process = estimation_process.EstimationProcess(
      _test_init_fn, make_growing_next_fn(0.75), _test_report_fn)
  clipping_norm_process = estimation_process.EstimationProcess(
      _test_init_fn, make_growing_next_fn(0.25), _test_report_fn)

  agg_factory = robust.zeroing_factory(zeroing_norm_process,
                                       _clipped_sum(clipping_norm_process))
  process = agg_factory.create(computation_types.to_type(tf.float32))

  client_data = [1.0, 2.0, 3.0]
  # (zeroing_norm, clipping_norm, zeroed_count, clipped_count, result) per round.
  expected_rounds = [
      (1.0, 1.0, 2, 0, 1.0),
      (1.75, 1.25, 2, 0, 1.0),
      (2.5, 1.5, 1, 1, 2.5),
      (3.25, 1.75, 0, 2, 4.5),
      (4.0, 2.0, 0, 1, 5.0),
  ]

  state = process.initialize()
  for (zeroing_norm, clipping_norm, zeroed_count, clipped_count,
       expected_result) in expected_rounds:
    output = process.next(state, client_data)
    measurements = output.measurements
    self.assertAllClose(zeroing_norm, measurements['zeroing_norm'])
    self.assertAllClose(clipping_norm,
                        measurements['zeroing']['clipping_norm'])
    self.assertEqual(zeroed_count, measurements['zeroed_count'])
    self.assertEqual(clipped_count, measurements['zeroing']['clipped_count'])
    self.assertAllClose(expected_result, output.result)
    state = output.state
from tensorflow_federated.python.core.impl.tensorflow_context import tensorflow_computation
from tensorflow_federated.python.core.impl.types import computation_types
from tensorflow_federated.python.core.impl.types import placements
from tensorflow_federated.python.core.templates import errors
from tensorflow_federated.python.core.templates import measured_process
from tensorflow_federated.python.learning import model_utils
from tensorflow_federated.python.learning.templates import client_works

# Federated types shared by the tests in this module.
SERVER_INT = computation_types.FederatedType(tf.int32, placements.SERVER)
SERVER_FLOAT = computation_types.FederatedType(tf.float32, placements.SERVER)
CLIENTS_FLOAT_SEQUENCE = computation_types.FederatedType(
    computation_types.SequenceType(tf.float32), placements.CLIENTS)
CLIENTS_FLOAT = computation_types.FederatedType(tf.float32, placements.CLIENTS)
CLIENTS_INT = computation_types.FederatedType(tf.int32, placements.CLIENTS)
MODEL_WEIGHTS_TYPE = computation_types.at_clients(
    computation_types.to_type(model_utils.ModelWeights(tf.float32, ())))
MeasuredProcessOutput = measured_process.MeasuredProcessOutput


def server_zero():
  """Returns the constant `0` placed at SERVER."""
  zero = 0
  return intrinsics.federated_value(zero, placements.SERVER)


def client_one():
  """Returns the constant `1.0` placed at CLIENTS."""
  one = 1.0
  return intrinsics.federated_value(one, placements.CLIENTS)


def federated_add(a, b):
  """Maps `x + y` over the pair `(a, b)` with `federated_map`."""
  adder = tensorflow_computation.tf_computation(lambda x, y: x + y)
  operands = (a, b)
  return intrinsics.federated_map(adder, operands)
def test_tff_value_types_raise_on(self, value_type):
  """`create` raises a `TypeError` for the parameterized unsupported type."""
  test_factory = _make_test_factory()
  tff_type = computation_types.to_type(value_type)
  with self.assertRaisesRegex(TypeError, 'Expected `value_type` to be'):
    test_factory.create(tff_type)
def test_component_tensor_dtypes_raise_on(self, value_type):
  """`create` raises a `TypeError` for the parameterized unsupported dtypes."""
  ddp_factory = _make_test_factory()
  tff_type = computation_types.to_type(value_type)
  with self.assertRaisesRegex(TypeError, 'must all be integers or floats'):
    ddp_factory.create(tff_type)
def test_type_properties(self, value_type, mechanism):
  """Checks process signatures and that `next` uses only secure aggregation.

  The expected signatures come from an explicitly composed chain of
  factories: concat -> Hadamard rotation -> L2 clipping -> discretization ->
  DP query -> secure modular sum.
  """
  ddp_factory = _make_test_factory(mechanism=mechanism)
  self.assertIsInstance(ddp_factory, factory.UnweightedAggregationFactory)
  value_type = computation_types.to_type(value_type)
  process = ddp_factory.create(value_type)
  self.assertIsInstance(process, aggregation_process.AggregationProcess)

  # The state is a nested object with component factory states. Construct
  # test factories directly and compare the signatures.
  modsum_f = secure.SecureModularSumFactory(2**15, True)

  if mechanism == 'distributed_dgauss':
    dp_query = tfp.DistributedDiscreteGaussianSumQuery(
        l2_norm_bound=10.0, local_stddev=10.0)
  else:
    dp_query = tfp.DistributedSkellamSumQuery(
        l1_norm_bound=10.0, l2_norm_bound=10.0, local_stddev=10.0)

  dp_f = differential_privacy.DifferentiallyPrivateFactory(dp_query, modsum_f)
  discrete_f = discretization.DiscretizationFactory(dp_f)
  l2clip_f = robust.clipping_factory(
      clipping_norm=10.0, inner_agg_factory=discrete_f)
  rot_f = rotation.HadamardTransformFactory(inner_agg_factory=l2clip_f)
  expected_process = concat.concat_factory(rot_f).create(value_type)

  # Check init_fn/state.
  expected_init_type = expected_process.initialize.type_signature
  expected_state_type = expected_init_type.result
  actual_init_type = process.initialize.type_signature
  self.assertTrue(actual_init_type.is_equivalent_to(expected_init_type))

  # Check next_fn/measurements.
  tensor2type = type_conversions.type_from_tensors
  discrete_state = discrete_f.create(
      computation_types.to_type(tf.float32)).initialize()
  dp_query_state = dp_query.initial_global_state()
  dp_query_metrics_type = tensor2type(dp_query.derive_metrics(dp_query_state))
  expected_measurements_type = collections.OrderedDict(
      l2_clip=robust.NORM_TF_TYPE,
      scale_factor=tensor2type(discrete_state['scale_factor']),
      scaled_inflated_l2=tensor2type(dp_query_state.l2_norm_bound),
      scaled_local_stddev=tensor2type(dp_query_state.local_stddev),
      actual_num_clients=tf.int32,
      padded_dim=tf.int32,
      dp_query_metrics=dp_query_metrics_type)
  expected_next_type = computation_types.FunctionType(
      parameter=collections.OrderedDict(
          state=expected_state_type,
          value=computation_types.at_clients(value_type)),
      result=measured_process.MeasuredProcessOutput(
          state=expected_state_type,
          result=computation_types.at_server(value_type),
          measurements=computation_types.at_server(
              expected_measurements_type)))
  actual_next_type = process.next.type_signature
  self.assertTrue(actual_next_type.is_equivalent_to(expected_next_type))

  # Any error from the static check is reported as a test failure. A bare
  # `except:` would also trap KeyboardInterrupt/SystemExit, so catch
  # `Exception` only.
  try:
    static_assert.assert_not_contains_unsecure_aggregation(process.next)
  except Exception:  # pylint: disable=broad-except
    self.fail('Factory returned an AggregationProcess containing '
              'non-secure aggregation.')
def test_call_returns_result(self):
  """Checks dispatch and caching of concrete computations by argument type.

  `TestFunctionFactory` numbers each concrete computation it creates, and
  `TestContext.invoke` echoes that number back as `name=`. The expected
  strings below therefore depend on call order: the first call for each new
  argument type creates a new number. A later call with an already-seen type
  reports the earlier number (e.g. `fn(50)` reports `name=1`, like `fn(10)`).
  """

  class TestContext(context_base.Context):
    # Returns ingested values unchanged; `invoke` describes the call as a
    # string instead of executing anything.

    def ingest(self, val, type_spec):
      return val

    def invoke(self, comp, arg):
      return 'name={},type={},arg={},unpack={}'.format(
          comp.name, comp.type_signature.parameter, arg, comp.unpack)

  class TestContextStack(context_stack_base.ContextStack):
    # Always exposes the single `TestContext` as the current context.

    def __init__(self):
      super().__init__()
      self._context = TestContext()

    @property
    def current(self):
      return self._context

    def install(self, ctx):
      del ctx  # Unused
      return self._context

  context_stack = TestContextStack()

  class TestFunction(computation_impl.ConcreteComputation):
    # A concrete computation with a string result type that also records the
    # `name` and `unpack` values supplied by the factory.

    def __init__(self, name, unpack, parameter_type):
      self._name = name
      self._unpack = unpack
      type_signature = computation_types.FunctionType(parameter_type,
                                                      tf.string)
      test_proto = pb.Computation(
          type=type_serialization.serialize_type(type_signature))
      super().__init__(test_proto, context_stack, type_signature)

    @property
    def name(self):
      return self._name

    @property
    def unpack(self):
      return self._unpack

  class TestFunctionFactory(object):
    # Counts invocations, so each concrete computation gets a distinct name
    # ('1', '2', ...) in creation order.

    def __init__(self):
      self._count = 0

    def __call__(self, parameter_type, unpack):
      self._count = self._count + 1
      return TestFunction(str(self._count), str(unpack), parameter_type)

  fn = function_utils.PolymorphicComputation(TestFunctionFactory())

  # Per the expected strings, calls through `fn(...)` report unpack=True and
  # `fn_for_argument_type` reports unpack=None.
  self.assertEqual(fn(10), 'name=1,type=<int32>,arg=<10>,unpack=True')
  self.assertEqual(
      fn(20, x=True), 'name=2,type=<int32,x=bool>,arg=<20,x=True>,unpack=True')
  fn_with_bool_arg = fn.fn_for_argument_type(
      computation_types.to_type(tf.bool))
  self.assertEqual(
      fn_with_bool_arg(True), 'name=3,type=bool,arg=True,unpack=None')
  self.assertEqual(
      fn(30, x=40), 'name=4,type=<int32,x=int32>,arg=<30,x=40>,unpack=True')
  # Repeat each argument type seen above: the names match the first calls.
  self.assertEqual(fn(50), 'name=1,type=<int32>,arg=<50>,unpack=True')
  self.assertEqual(
      fn(0, x=False), 'name=2,type=<int32,x=bool>,arg=<0,x=False>,unpack=True')
  fn_with_bool_arg = fn.fn_for_argument_type(
      computation_types.to_type(tf.bool))
  self.assertEqual(
      fn_with_bool_arg(False), 'name=3,type=bool,arg=False,unpack=None')
  self.assertEqual(
      fn(60, x=70), 'name=4,type=<int32,x=int32>,arg=<60,x=70>,unpack=True')
te.encoders.hadamard_quantization(8), value_spec) def _one_over_n_encoder_fn(value_spec): return te.encoders.as_gather_encoder( te.core.EncoderComposer(te.testing.PlusOneOverNEncodingStage()).make(), value_spec) def _state_update_encoder_fn(value_spec): return te.encoders.as_gather_encoder( te.core.EncoderComposer(StateUpdateTensorsEncodingStage()).make(), value_spec) _test_struct_type = computation_types.to_type(((tf.float32, (20,)), tf.float32)) class EncodedSumFactoryComputationTest(test_case.TestCase, parameterized.TestCase): @parameterized.named_parameters( ('identity_from_encoder_fn', _identity_encoder_fn), ('uniform_from_encoder_fn', _uniform_encoder_fn), ('hadamard_from_encoder_fn', _hadamard_encoder_fn), ('one_over_n_from_encoder_fn', _one_over_n_encoder_fn), ('state_update_from_encoder_fn', _state_update_encoder_fn), ) def test_type_properties(self, encoder_fn): encoded_f = encoded.EncodedSumFactory(encoder_fn) self.assertIsInstance(encoded_f, factory.UnweightedAggregationFactory)
async def create_value(self, value, type_spec=None):
    """Embeds `value` in this executor.

    Accepted inputs:

    * A TFF computation proto whose sole body is one of the supported sequence
      intrinsics.
    * An eager TF dataset.
    * Anything the target executor accepts (forwarded unchanged).
    * A nested structure of any of the above.

    Args:
      value: The value to embed.
      type_spec: An optional TFF type. It is required unless `value` is a
        `typed_object.TypedObject`; in that case it may be `None`.

    Returns:
      A `SequenceExecutorValue` wrapping the embedded value.
    """
    if type_spec is not None:
        type_spec = computation_types.to_type(type_spec)
    else:
        py_typecheck.check_type(value, typed_object.TypedObject)
        type_spec = value.type_signature

    if isinstance(type_spec, computation_types.SequenceType):
        payload = _SequenceFromPayload(value, type_spec)
        return SequenceExecutorValue(payload, type_spec)

    if isinstance(value, pb.Computation):
        proto_type = type_serialization.deserialize_type(value.type)
        proto_type.check_equivalent_to(type_spec)
        # Only intrinsics that map to a sequence op are handled here. Every
        # other computation falls through to the target executor below.
        if value.WhichOneof('computation') == 'intrinsic':
            uri = value.intrinsic.uri
            intrinsic_def = intrinsic_defs.uri_to_intrinsic_def(uri)
            if intrinsic_def is None:
                raise ValueError(
                    'Encountered an unrecognized intrinsic "{}".'.format(uri))
            op_type = SequenceExecutor._SUPPORTED_INTRINSIC_TO_SEQUENCE_OP.get(
                intrinsic_def.uri)
            if op_type is not None:
                type_analysis.check_concrete_instance_of(
                    type_spec, intrinsic_def.type_signature)
                return SequenceExecutorValue(op_type(type_spec), type_spec)

    if isinstance(type_spec, computation_types.StructType):
        struct_value = (
            value if isinstance(value, structure.Struct)
            else structure.from_container(value))
        leaf_values = structure.flatten(struct_value)
        leaf_types = structure.flatten(type_spec)
        # Embed every leaf concurrently, then rebuild the original nesting.
        embedded_leaves = await asyncio.gather(
            *(self.create_value(leaf, leaf_type)
              for leaf, leaf_type in zip(leaf_values, leaf_types)))
        return await self.create_struct(
            structure.pack_sequence_as(struct_value, embedded_leaves))

    embedded = await self._target_executor.create_value(value, type_spec)
    return SequenceExecutorValue(embedded, type_spec)
def infer_type(arg: Any) -> Optional[computation_types.Type]:
    """Infers the TFF type (a `computation_types.Type`) of `arg`.

    Warning: This function is only partially implemented. It recognizes:

    * tensors (including `tf.RaggedTensor` and `tf.SparseTensor`), variables,
      and datasets;
    * values convertible to tensors, such as `numpy` arrays, builtin scalars,
      `str`, and `list`s/`tuple`s of such values;
    * nested `list`s, `tuple`s, `namedtuple`s, `structure.Struct`s, `dict`s,
      `OrderedDict`s, `dataclasses`, `attrs` classes and `tff.TypedObject`s.

    Args:
      arg: The value whose TFF type should be inferred.

    Returns:
      A `computation_types.Type`, or `None` when `arg` is `None`.

    Raises:
      TypeError: If `arg` is a Python `int` outside the `int64` range, or if no
        type can be inferred for `arg`.
    """
    if arg is None:
        return None
    if isinstance(arg, typed_object.TypedObject):
        return arg.type_signature

    if tf.is_tensor(arg):
        # `tf.is_tensor` is also true for composite tensors. Those are modelled
        # as structs of their component tensors.
        if isinstance(arg, tf.RaggedTensor):
            return computation_types.StructWithPythonType(
                (('flat_values', infer_type(arg.flat_values)),
                 ('nested_row_splits', infer_type(arg.nested_row_splits))),
                tf.RaggedTensor)
        if isinstance(arg, tf.SparseTensor):
            return computation_types.StructWithPythonType(
                (('indices', infer_type(arg.indices)),
                 ('values', infer_type(arg.values)),
                 ('dense_shape', infer_type(arg.dense_shape))),
                tf.SparseTensor)
        return computation_types.TensorType(arg.dtype.base_dtype, arg.shape)

    if isinstance(arg, TF_DATASET_REPRESENTATION_TYPES):
        element_type = computation_types.to_type(arg.element_spec)
        return computation_types.SequenceType(element_type)

    if isinstance(arg, structure.Struct):
        member_types = []
        for name, member in structure.iter_elements(arg):
            member_type = infer_type(member)
            member_types.append((name, member_type) if name else member_type)
        return computation_types.StructType(member_types)

    if py_typecheck.is_attrs(arg):
        fields = named_containers.attrs_class_to_odict(arg)
        return computation_types.StructWithPythonType(
            [(name, infer_type(field)) for name, field in fields.items()],
            type(arg))

    if py_typecheck.is_dataclass(arg):
        fields = named_containers.dataclass_to_odict(arg)
        return computation_types.StructWithPythonType(
            [(name, infer_type(field)) for name, field in fields.items()],
            type(arg))

    if py_typecheck.is_named_tuple(arg):
        # On Python 3.8+ `_asdict` returns a plain `dict`, so normalize it to
        # an `OrderedDict` to keep the field order explicit.
        fields = collections.OrderedDict(arg._asdict())
        return computation_types.StructWithPythonType(
            [(name, infer_type(field)) for name, field in fields.items()],
            type(arg))

    if isinstance(arg, dict):
        # Plain dicts are ordered by sorted key so the field order is
        # deterministic. `OrderedDict`s keep their own order.
        if isinstance(arg, collections.OrderedDict):
            entries = arg.items()
        else:
            entries = sorted(arg.items())
        return computation_types.StructWithPythonType(
            [(name, infer_type(entry)) for name, entry in entries], type(arg))

    if isinstance(arg, (tuple, list)):
        every_item_is_pair = all(
            py_typecheck.is_name_value_pair(item) for item in arg)
        item_types = [infer_type(item) for item in arg]
        # A non-empty sequence made only of (name, value) pairs is most likely
        # meant as a struct, so the Python container is not recorded.
        # NOTE(review): each pair is typed as a whole (via `infer_type`), so the
        # names do not appear to become field names here. Confirm intended.
        if item_types and every_item_is_pair:
            return computation_types.StructType(item_types)
        return computation_types.StructWithPythonType(item_types, type(arg))

    if isinstance(arg, str):
        return computation_types.TensorType(tf.string)

    if isinstance(arg, (np.generic, np.ndarray)):
        return computation_types.TensorType(
            tf.dtypes.as_dtype(arg.dtype), arg.shape)

    python_type = type(arg)
    if python_type is bool:
        return computation_types.TensorType(tf.bool)
    if python_type is int:
        # Use int32 when the value fits, otherwise int64. Reject values that
        # fit neither.
        if not tf.int64.min <= arg <= tf.int64.max:
            raise TypeError(
                'No integral type support for values outside range '
                f'[{tf.int64.min}, {tf.int64.max}]. Got: {arg}')
        if tf.int32.min <= arg <= tf.int32.max:
            return computation_types.TensorType(tf.int32)
        return computation_types.TensorType(tf.int64)
    if python_type is float:
        return computation_types.TensorType(tf.float32)

    # Last resort: let TensorFlow convert the value. This types it the same way
    # TF would, e.g. Python ints as int32 rather than NumPy's int64.
    try:
        # TODO(b/113112885): Find something more lightweight we could use here.
        tensor_proto = tf.make_tensor_proto(arg)
        return computation_types.TensorType(
            tf.dtypes.as_dtype(tensor_proto.dtype),
            tf.TensorShape(tensor_proto.tensor_shape))
    except TypeError as e:
        raise TypeError('Could not infer the TFF type of {}.'.format(
            py_typecheck.type_string(type(arg)))) from e
def test_inner_federated_type_raises(self):
    """A struct of federated values is rejected as the broadcast value type."""
    with self.assertRaisesRegex(TypeError, 'FederatedType'):
        struct_of_federated = computation_types.to_type(
            [SERVER_FLOAT, SERVER_FLOAT])
        distributors.build_broadcast_process(struct_of_federated)
def _make_wrapper(clipping_norm: Union[float,
                                       estimation_process.EstimationProcess],
                  inner_agg_factory: factory.AggregationFactory,
                  make_clip_fn: Callable[[factory.ValueType],
                                         computation_base.Computation],
                  attribute_prefix: str) -> factory.AggregationFactory:
    """Constructs an aggregation factory that applies clip_fn before aggregation.

    State and measurement names are built by appending fixed suffixes to
    `attribute_prefix`. The state keys are `attribute_prefix + 'ing_norm'`,
    `'inner_agg'` and `attribute_prefix + 'ed_count_agg'`. The measurement keys
    are `attribute_prefix + 'ing'`, `attribute_prefix + 'ing_norm'` and
    `attribute_prefix + 'ed_count'`.

    Args:
      clipping_norm: Either a float (for fixed norm) or an `EstimationProcess`
        (for adaptive norm) that specifies the norm over which the values
        should be clipped.
      inner_agg_factory: A factory specifying the type of aggregation to be done
        after clipping.
      make_clip_fn: A callable that takes a value type and returns a
        tff.computation specifying the clip operation to apply before
        aggregation. The computation's result is unpacked below as
        `(clipped_value, global_norm, was_clipped)`.
      attribute_prefix: A str for prefixing state and measurement names.

    Returns:
      An aggregation factory that applies clip_fn before aggregation. It is a
      `WeightedAggregationFactory` if `inner_agg_factory` is weighted, and an
      `UnweightedAggregationFactory` otherwise.
    """
    py_typecheck.check_type(inner_agg_factory,
                            (factory.UnweightedAggregationFactory,
                             factory.WeightedAggregationFactory))
    py_typecheck.check_type(clipping_norm,
                            (float, estimation_process.EstimationProcess))
    # A fixed float norm is wrapped in a constant process so both cases share
    # the same estimation-process code path below.
    if isinstance(clipping_norm, float):
        clipping_norm_process = _constant_process(clipping_norm)
    else:
        clipping_norm_process = clipping_norm
    _check_norm_process(clipping_norm_process, 'clipping_norm_process')

    # The aggregation factory that will be used to count the number of clipped
    # values at each iteration. For now we are just creating it here, but in
    # the future we may make this customizable to allow DP measurements.
    clipped_count_agg_factory = sum_factory.SumFactory()
    clipped_count_agg_process = clipped_count_agg_factory.create(
        computation_types.to_type(COUNT_TF_TYPE))

    # Builds state/measurement key names; see the docstring for the resulting keys.
    prefix = lambda s: attribute_prefix + s

    def init_fn_impl(inner_agg_process):
        # State order (norm, inner_agg, clipped count) must match the tuple
        # unpacking at the top of `next_fn_impl`.
        state = collections.OrderedDict([
            (prefix('ing_norm'), clipping_norm_process.initialize()),
            ('inner_agg', inner_agg_process.initialize()),
            (prefix('ed_count_agg'), clipped_count_agg_process.initialize())
        ])
        return intrinsics.federated_zip(state)

    def next_fn_impl(state, value, clip_fn, inner_agg_process, weight=None):
        # Shared `next` body for the weighted (`weight` given) and unweighted
        # (`weight is None`) factories.
        clipping_norm_state, agg_state, clipped_count_state = state

        clipping_norm = clipping_norm_process.report(clipping_norm_state)

        # Clients need the current norm to clip their values.
        clients_clipping_norm = intrinsics.federated_broadcast(clipping_norm)

        # Averaging the broadcast copies yields a server-placed norm that is
        # reported in the measurements below.
        # TODO(b/163880757): Remove this when server-only metrics are supported.
        clipping_norm = intrinsics.federated_mean(clients_clipping_norm)

        # NOTE(review): `global_norm` is presumably the per-client norm before
        # clipping; it feeds the norm estimator. Confirm against make_clip_fn.
        clipped_value, global_norm, was_clipped = intrinsics.federated_map(
            clip_fn, (value, clients_clipping_norm))

        new_clipping_norm_state = clipping_norm_process.next(
            clipping_norm_state, global_norm)

        if weight is None:
            agg_output = inner_agg_process.next(agg_state, clipped_value)
        else:
            agg_output = inner_agg_process.next(agg_state, clipped_value, weight)

        clipped_count_output = clipped_count_agg_process.next(
            clipped_count_state, was_clipped)

        new_state = collections.OrderedDict([
            (prefix('ing_norm'), new_clipping_norm_state),
            ('inner_agg', agg_output.state),
            (prefix('ed_count_agg'), clipped_count_output.state)
        ])
        measurements = collections.OrderedDict([
            (prefix('ing'), agg_output.measurements),
            (prefix('ing_norm'), clipping_norm),
            (prefix('ed_count'), clipped_count_output.result)
        ])

        return measured_process.MeasuredProcessOutput(
            state=intrinsics.federated_zip(new_state),
            result=agg_output.result,
            measurements=intrinsics.federated_zip(measurements))

    # Mirror the weighted/unweighted kind of `inner_agg_factory` so callers get
    # a factory with the same `create` signature as the inner one.
    if isinstance(inner_agg_factory, factory.WeightedAggregationFactory):

        class WeightedRobustFactory(factory.WeightedAggregationFactory):
            """`WeightedAggregationFactory` factory for clipping large values."""

            def create(
                self, value_type: factory.ValueType, weight_type: factory.ValueType
            ) -> aggregation_process.AggregationProcess:
                _check_value_type(value_type)
                py_typecheck.check_type(weight_type, factory.ValueType.__args__)

                inner_agg_process = inner_agg_factory.create(value_type, weight_type)
                clip_fn = make_clip_fn(value_type)

                @computations.federated_computation()
                def init_fn():
                    return init_fn_impl(inner_agg_process)

                @computations.federated_computation(
                    init_fn.type_signature.result,
                    computation_types.at_clients(value_type),
                    computation_types.at_clients(weight_type))
                def next_fn(state, value, weight):
                    return next_fn_impl(state, value, clip_fn, inner_agg_process,
                                        weight)

                return aggregation_process.AggregationProcess(init_fn, next_fn)

        return WeightedRobustFactory()
    else:

        class UnweightedRobustFactory(factory.UnweightedAggregationFactory):
            """`UnweightedAggregationFactory` factory for clipping large values."""

            def create(
                self, value_type: factory.ValueType
            ) -> aggregation_process.AggregationProcess:
                _check_value_type(value_type)

                inner_agg_process = inner_agg_factory.create(value_type)
                clip_fn = make_clip_fn(value_type)

                @computations.federated_computation()
                def init_fn():
                    return init_fn_impl(inner_agg_process)

                @computations.federated_computation(
                    init_fn.type_signature.result,
                    computation_types.at_clients(value_type))
                def next_fn(state, value):
                    return next_fn_impl(state, value, clip_fn, inner_agg_process)

                return aggregation_process.AggregationProcess(init_fn, next_fn)

        return UnweightedRobustFactory()
def build_federated_evaluation(
    model_fn: Callable[[], model_lib.Model],
    broadcast_process: Optional[measured_process.MeasuredProcess] = None,
    use_experimental_simulation_loop: bool = False,
) -> computation_base.Computation:
    """Builds the TFF computation for federated evaluation of the given model.

    Args:
      model_fn: A no-arg function that returns a `tff.learning.Model`. This
        method must *not* capture TensorFlow tensors or variables and use them.
        The model must be constructed entirely from scratch on each invocation.
        Returning the same pre-constructed model on each call will result in an
        error.
      broadcast_process: A `tff.templates.MeasuredProcess` that broadcasts the
        model weights on the server to the clients. It must support the
        signature `(input_values@SERVER -> output_values@CLIENTS)` and have
        empty state. If set to the default `None`, the server model is broadcast
        to the clients using the default tff.federated_broadcast.
      use_experimental_simulation_loop: Controls the reduce loop function for
        the input dataset. An experimental reduce loop is used for simulation.

    Returns:
      A federated computation (an instance of `tff.Computation`) that accepts
      model parameters and federated data. It returns a federated-zipped
      `OrderedDict` with two entries:
      * `eval`: the evaluation metrics as aggregated by
        `tff.learning.Model.federated_output_computation`.
      * `stat`: an `OrderedDict` holding `num_examples`, the federated sum of
        the per-client example counts.

    Raises:
      ValueError: If `broadcast_process` is given but is not a
        `MeasuredProcess`, or if it is stateful.
    """
    if broadcast_process is not None:
        if not isinstance(broadcast_process, measured_process.MeasuredProcess):
            raise ValueError('`broadcast_process` must be a `MeasuredProcess`, got '
                             f'{type(broadcast_process)}.')
        # A fresh `initialize()` state is used on every evaluation round below,
        # which is only sound for stateless processes.
        if optimizer_utils.is_stateful_process(broadcast_process):
            raise ValueError(
                'Cannot create a federated evaluation with a stateful '
                'broadcast process, must be stateless, has state: '
                f'{broadcast_process.initialize.type_signature.result!r}')
    # Construct the model first just to obtain the metadata and define all the
    # types needed to define the computations that follow.
    # TODO(b/124477628): Ideally replace the need for stamping throwaway models
    # with some other mechanism.
    with tf.Graph().as_default():
        model = model_fn()
        model_weights_type = model_utils.weights_type_from_model(model)
        batch_type = computation_types.to_type(model.input_spec)

    @computations.tf_computation(model_weights_type, SequenceType(batch_type))
    @tf.function
    def client_eval(incoming_model_weights, dataset):
        """Returns local outputs after evaluating `model_weights` on `dataset`."""
        # Build the client model's variables outside the traced function body,
        # then overwrite them with the incoming (broadcast) weights.
        with tf.init_scope():
            model = model_fn()
        model_weights = model_utils.ModelWeights.from_model(model)
        tf.nest.map_structure(lambda v, t: v.assign(t), model_weights,
                              incoming_model_weights)

        def reduce_fn(num_examples, batch):
            # Runs inference on one batch and adds its example count to the
            # running int64 total.
            model_output = model.forward_pass(batch, training=False)
            if model_output.num_examples is None:
                # Compute shape from the size of the predictions if model didn't use the
                # batch size.
                return num_examples + tf.shape(
                    model_output.predictions, out_type=tf.int64)[0]
            else:
                return num_examples + tf.cast(model_output.num_examples, tf.int64)

        dataset_reduce_fn = dataset_reduce.build_dataset_reduce_fn(
            use_experimental_simulation_loop)
        num_examples = dataset_reduce_fn(
            reduce_fn=reduce_fn,
            dataset=dataset,
            initial_state_fn=lambda: tf.zeros([], dtype=tf.int64))
        return collections.OrderedDict(
            local_outputs=model.report_local_outputs(), num_examples=num_examples)

    @computations.federated_computation(
        computation_types.at_server(model_weights_type),
        computation_types.at_clients(SequenceType(batch_type)))
    def server_eval(server_model_weights, federated_dataset):
        if broadcast_process is not None:
            # TODO(b/179091838): Zip the measurements from the broadcast_process with
            # the result of `model.federated_output_computation` below to avoid
            # dropping these metrics.
            broadcast_output = broadcast_process.next(broadcast_process.initialize(),
                                                      server_model_weights)
            client_outputs = intrinsics.federated_map(
                client_eval, (broadcast_output.result, federated_dataset))
        else:
            client_outputs = intrinsics.federated_map(client_eval, [
                intrinsics.federated_broadcast(server_model_weights), federated_dataset
            ])
        # `model` here is the throwaway model built above. It is used only for
        # its `federated_output_computation` aggregation.
        model_metrics = model.federated_output_computation(
            client_outputs.local_outputs)
        statistics = collections.OrderedDict(
            num_examples=intrinsics.federated_sum(client_outputs.num_examples))
        return intrinsics.federated_zip(
            collections.OrderedDict(eval=model_metrics, stat=statistics))

    return server_eval
def test_build_tf_computations_for_sum(self, encoder_constructor):
    """The partial gather computations agree on their input/output types."""
    spec = tf.TensorSpec((20,), tf.float32)
    gather_encoder = te.encoders.as_gather_encoder(encoder_constructor(), spec)

    _, state_type = encoding_utils._build_initial_state_tf_computation(
        gather_encoder)
    value_type = computation_types.to_type(spec)
    built = encoding_utils._build_tf_computations_for_gather(
        state_type, value_type, gather_encoder)

    # get_params: state -> (encode params, decode-before-sum params,
    #                       decode-after-sum params).
    get_params_sig = built.get_params_fn.type_signature
    self.assertEqual(state_type, get_params_sig.parameter)
    encode_params_type = get_params_sig.result[0]
    decode_before_sum_params_type = get_params_sig.result[1]
    decode_after_sum_params_type = get_params_sig.result[2]

    # encode: (value, encode params, decode-before-sum params) -> (..., ..., updates)
    encode_sig = built.encode_fn.type_signature
    self.assertEqual(value_type, encode_sig.parameter[0])
    self.assertEqual(encode_params_type, encode_sig.parameter[1])
    self.assertEqual(decode_before_sum_params_type, encode_sig.parameter[2])
    state_update_tensors_type = encode_sig.result[2]

    accumulator_type = built.zero_fn.type_signature.result
    self.assertEqual(state_update_tensors_type,
                     accumulator_type.state_update_tensors)

    accumulate_sig = built.accumulate_fn.type_signature
    self.assertEqual(accumulator_type, accumulate_sig.parameter[0])
    self.assertEqual(encode_sig.result, accumulate_sig.parameter[1])

    # Every accumulator slot of accumulate/merge/report carries the same type.
    merge_sig = built.merge_fn.type_signature
    report_sig = built.report_fn.type_signature
    for observed in (accumulate_sig.result,
                     merge_sig.parameter[0],
                     merge_sig.parameter[1],
                     merge_sig.result,
                     report_sig.parameter,
                     report_sig.result):
        self.assertEqual(accumulator_type, observed)

    decode_sig = built.decode_after_sum_fn.type_signature
    self.assertEqual(accumulator_type.values, decode_sig.parameter[0])
    self.assertEqual(decode_after_sum_params_type, decode_sig.parameter[1])
    self.assertEqual(value_type, decode_sig.result)

    update_sig = built.update_state_fn.type_signature
    self.assertEqual(state_type, update_sig.parameter[0])
    self.assertEqual(state_update_tensors_type, update_sig.parameter[1])
    self.assertEqual(state_type, update_sig.result)
def embed_tensorflow_computation(comp, type_spec=None, device=None):
    """Embeds a TensorFlow computation for use in the eager context.

    If the result type contains a sequence, the returned callable runs under
    `tf.device('cpu')`. Otherwise, if `device` is given, it runs under that
    device.

    Args:
      comp: An instance of `pb.Computation`.
      type_spec: An optional `tff.Type` instance or something convertible to it.
        When given, it must be equivalent to the type serialized in `comp`.
      device: An optional `tf.config.LogicalDevice`.

    Returns:
      Either a one-argument or a zero-argument callable that executes the
      computation in eager mode. It takes one argument when the computation's
      type is a function type with a parameter, and none otherwise.

    Raises:
      TypeError: If arguments are of the wrong types, e.g., if `comp` is not a
        TensorFlow computation or `type_spec` does not match its type.
    """
    # TODO(b/134543154): Decide whether this belongs in `tensorflow_utils.py`
    # since it deals exclusively with eager mode. Incubate here, and potentially
    # move there, once stable.

    py_typecheck.check_type(comp, pb.Computation)
    comp = _ensure_comp_runtime_compatible(comp)
    comp_type = type_serialization.deserialize_type(comp.type)
    type_spec = computation_types.to_type(type_spec)
    if type_spec is not None:
        if not type_spec.is_equivalent_to(comp_type):
            raise TypeError('Expected a computation of type {}, got {}.'.format(
                type_spec, comp_type))
    else:
        type_spec = comp_type

    which_computation = comp.WhichOneof('computation')
    if which_computation != 'tensorflow':
        unexpected_building_block = (
            building_blocks.ComputationBuildingBlock.from_proto(comp))
        raise TypeError('Expected a TensorFlow computation, found {}.'.format(
            unexpected_building_block))

    if type_spec.is_function():
        param_type = type_spec.parameter
        result_type = type_spec.result
    else:
        param_type = None
        result_type = type_spec

    # TODO(b/156302055): Currently, TF will raise on any function returning a
    # `tf.data.Dataset` not pinned to CPU. We should follow up here and remove
    # this gating when we can.
    # Derived from `result_type`, not `type_spec.result`. Non-function types
    # have no `.result` attribute, so the non-function branch above needs this.
    must_pin_function_to_cpu = type_analysis.contains(result_type,
                                                      lambda t: t.is_sequence())

    wrapped_fn = _get_wrapped_function_from_comp(comp, must_pin_function_to_cpu,
                                                 param_type, device)

    # Per flattened parameter: tensors pass through, datasets become variants.
    param_fns = []
    if param_type is not None:
        for spec in structure.flatten(param_type):
            if spec.is_tensor():
                param_fns.append(lambda x: x)
            else:
                py_typecheck.check_type(spec, computation_types.SequenceType)
                param_fns.append(tf.data.experimental.to_variant)

    # Per flattened result: tensors pass through, variants become datasets.
    result_fns = []
    for spec in structure.flatten(result_type):
        if spec.is_tensor():
            result_fns.append(lambda x: x)
        else:
            py_typecheck.check_type(spec, computation_types.SequenceType)
            tf_structure = type_conversions.type_to_tf_structure(spec.element)

            # The default argument binds `tf_structure` for this iteration,
            # avoiding Python's late-binding closure behaviour.
            def fn(x, tf_structure=tf_structure):
                return tf.data.experimental.from_variant(x, tf_structure)

            result_fns.append(fn)

    # Collect resource handles in one pass over the graph. Hash-table resources
    # are passed to `_call_embedded_tf` as `destroy_before_invocation`, and
    # variable handles as `destroy_after_invocation`.
    eager_cleanup_ops = []
    lazy_cleanup_ops = []
    for op in wrapped_fn.graph.get_operations():
        if op.type == 'HashTableV2':
            eager_cleanup_ops += op.outputs
        elif op.type == 'VarHandleOp':
            lazy_cleanup_ops += op.outputs

    destroy_before_invocation = []
    if eager_cleanup_ops:
        destroy_before_invocation.extend(
            wrapped_fn.prune(feeds={}, fetches=eager_cleanup_ops)())

    destroy_after_invocation = []
    if lazy_cleanup_ops:
        destroy_after_invocation.extend(
            wrapped_fn.prune(feeds={}, fetches=lazy_cleanup_ops)())

    def fn_to_return(arg,
                     param_fns=tuple(param_fns),
                     result_fns=tuple(result_fns),
                     result_type=result_type,
                     wrapped_fn=wrapped_fn,
                     destroy_before=tuple(destroy_before_invocation),
                     destroy_after=tuple(destroy_after_invocation)):
        # This double-function pattern works around python late binding, forcing the
        # variables to bind eagerly.
        return _call_embedded_tf(
            arg=arg,
            param_fns=param_fns,
            result_fns=result_fns,
            result_type=result_type,
            wrapped_fn=wrapped_fn,
            destroy_before_invocation=destroy_before,
            destroy_after_invocation=destroy_after)

    # pylint: disable=function-redefined
    if must_pin_function_to_cpu:
        old_fn_to_return = fn_to_return

        def fn_to_return(x):
            with tf.device('cpu'):
                return old_fn_to_return(x)
    elif device is not None:
        old_fn_to_return = fn_to_return

        def fn_to_return(x):
            with tf.device(device.name):
                return old_fn_to_return(x)
    # pylint: enable=function-redefined

    if param_type is not None:
        return lambda arg: fn_to_return(arg)  # pylint: disable=unnecessary-lambda
    else:
        return lambda: fn_to_return(None)
async def create_value(
        self, value: Any,
        type_spec: Any = None) -> executor_value_base.ExecutorValue:
    """Embeds `value` (optionally typed by `type_spec`) in this executor.

    Accepted kinds of `value`:

    * An `intrinsic_defs.IntrinsicDef`.
    * A `placements.PlacementLiteral`.
    * A `computation_impl.ConcreteComputation`, which is embedded through its
      proto.
    * A `pb.Computation` of kind intrinsic, lambda, tensorflow, xla, or data.
    * A Python `list` when `type_spec` is a federated type. Note: a list is
      required even for an `all_equal` type, and even when only one participant
      is associated with the placement.
    * A Python value when `type_spec` is a non-functional, non-federated type.

    Args:
      value: The object to embed; one of the kinds listed above.
      type_spec: An optional value convertible to a `tff.Type` via
        `tff.to_type`; the type of `value`.

    Returns:
      An `executor_value_base.ExecutorValue` embedded in the
      `FederatingExecutor` through its `FederatingStrategy`.

    Raises:
      TypeError: If `value` and `type_spec` do not match.
      ValueError: If `value` is not a kind the `FederatingExecutor` supports.
    """
    type_spec = computation_types.to_type(type_spec)

    if isinstance(value, intrinsic_defs.IntrinsicDef):
        type_analysis.check_concrete_instance_of(type_spec, value.type_signature)
        return self._strategy.ingest_value(value, type_spec)

    if isinstance(value, placements.PlacementLiteral):
        placement_type = (
            computation_types.PlacementType() if type_spec is None else type_spec)
        placement_type.check_placement()
        return self._strategy.ingest_value(value, placement_type)

    if isinstance(value, computation_impl.ConcreteComputation):
        proto = computation_impl.ConcreteComputation.get_proto(value)
        reconciled_type = executor_utils.reconcile_value_with_type_spec(
            value, type_spec)
        return await self.create_value(proto, reconciled_type)

    if isinstance(value, pb.Computation):
        proto_type = type_serialization.deserialize_type(value.type)
        if type_spec is None:
            type_spec = proto_type
        else:
            type_spec.check_assignable_from(proto_type)
        kind = value.WhichOneof('computation')
        if kind in ('lambda', 'tensorflow', 'xla', 'data'):
            return self._strategy.ingest_value(value, type_spec)
        if kind != 'intrinsic':
            raise ValueError(
                'Unsupported computation building block of type "{}".'.format(
                    kind))
        uri = value.intrinsic.uri
        # Some intrinsics are handed to the strategy as raw protos rather than
        # being resolved to an `IntrinsicDef`.
        if uri in FederatingExecutor._FORWARDED_INTRINSICS:
            return self._strategy.ingest_value(value, type_spec)
        intrinsic_def = intrinsic_defs.uri_to_intrinsic_def(uri)
        if intrinsic_def is None:
            raise ValueError('Encountered an unrecognized intrinsic "{}".'.format(
                uri))
        return await self.create_value(intrinsic_def, type_spec)

    if type_spec is not None and type_spec.is_federated():
        return await self._strategy.compute_federated_value(value, type_spec)

    unplaced = await self._unplaced_executor.create_value(value, type_spec)
    return self._strategy.ingest_value(unplaced, type_spec)
def save_functional_model(functional_model: functional.FunctionalModel,
                          path: str):
    """Serializes a `FunctionalModel` as a `tf.SavedModel` to `path`.

    The saved `tf.Module` holds concrete functions for these operations:
    * creating the initial weights;
    * the forward pass, once for training and once for inference;
    * `predict_on_batch`, once for training and once for inference.

    Each function returns a flattened structure. Next to each one, a
    non-trainable string `tf.Variable` stores the serialized `tff.Type` of the
    un-flattened result. The model's input spec is stored the same way.

    Args:
      functional_model: A `tff.learning.models.FunctionalModel`.
      path: A `str` directory path to serialize the model to.
    """
    m = tf.Module()
    # Serialize the initial_weights values as a tf.function that creates a
    # structure of tensors with the initial weights. This way we can add it to the
    # tf.SavedModel and call it to create initial weights after deserialization.
    create_initial_weights = lambda: functional_model.initial_weights
    with tf.Graph().as_default():
        concrete_structured_fn = tf.function(
            create_initial_weights).get_concrete_function()
    model_weights_tensor_specs = tf.nest.map_structure(
        tf.TensorSpec.from_tensor, concrete_structured_fn.structured_outputs)
    initial_weights_result_type_spec = type_serialization.serialize_type(
        computation_types.to_type(model_weights_tensor_specs))
    # Non-trainable, like every other serialized type-spec variable below. It
    # stores metadata, not a model parameter, so it must not appear in the
    # module's `trainable_variables`.
    m.create_initial_weights_type_spec = tf.Variable(
        initial_weights_result_type_spec.SerializeToString(deterministic=True),
        trainable=False)

    def flat_initial_weights():
        return tf.nest.flatten(create_initial_weights())

    with tf.Graph().as_default():
        m.create_initial_weights = tf.function(
            flat_initial_weights).get_concrete_function()

    # Serialize forward pass concretely, once for training and once for
    # non-training.
    # TODO(b/198150431): try making `training` a `tf.Tensor` parameter to remove
    # the need for serializing two different function graphs.
    def make_concrete_flat_forward_pass(training: bool):
        """Create a concrete forward_pass function that has flattened output.

        Args:
          training: A boolean indicating whether this is a call in a training
            loop, or evaluation loop.

        Returns:
          A 2-tuple of concrete `tf.function` instance and a `tff.Type` protocol
          buffer message documenting the result structure returned by the
          concrete function.
        """
        # Save the un-flattened type spec for deserialization later.
        # Note: `training` is a Python boolean, which gets "curried", in a sense,
        # during function concretization. The resulting concrete function only has
        # parameters for `model_weights` and `batch_input`, which are
        # `tf.TensorSpec` structures here.
        with tf.Graph().as_default():
            concrete_structured_fn = (
                functional_model.forward_pass.get_concrete_function(
                    model_weights_tensor_specs,
                    functional_model.input_spec,
                    # Note: training does not appear in the resulting concrete function.
                    training=training))
        output_tensor_spec_structure = tf.nest.map_structure(
            tf.TensorSpec.from_tensor, concrete_structured_fn.structured_outputs)
        result_type_spec = type_serialization.serialize_type(
            computation_types.to_type(output_tensor_spec_structure))

        @tf.function
        def flat_forward_pass(model_weights, batch_input, training):
            return tf.nest.flatten(
                functional_model.forward_pass(model_weights, batch_input, training))

        with tf.Graph().as_default():
            flat_concrete_fn = flat_forward_pass.get_concrete_function(
                model_weights_tensor_specs,
                functional_model.input_spec,
                # Note: training does not appear in the resulting concrete function.
                training=training)
        return flat_concrete_fn, result_type_spec

    fw_pass_training, fw_pass_training_type_spec = (
        make_concrete_flat_forward_pass(training=True))
    m.flat_forward_pass_training = fw_pass_training
    m.forward_pass_training_type_spec = tf.Variable(
        fw_pass_training_type_spec.SerializeToString(deterministic=True),
        trainable=False)

    fw_pass_inference, fw_pass_inference_type_spec = (
        make_concrete_flat_forward_pass(training=False))
    m.flat_forward_pass_inference = fw_pass_inference
    m.forward_pass_inference_type_spec = tf.Variable(
        fw_pass_inference_type_spec.SerializeToString(deterministic=True),
        trainable=False)

    # Serialize predict_on_batch, once for training, once for non-training.
    x_type = functional_model.input_spec[0]

    # TODO(b/198150431): try making `training` a `tf.Tensor` parameter to remove
    # the need for serializing two different function graphs.
    def make_concrete_flat_predict_on_batch(training: bool):
        """Create a concrete predict_on_batch function that has flattened output.

        Args:
          training: A boolean indicating whether this is a call in a training
            loop, or evaluation loop.

        Returns:
          A 2-tuple of concrete `tf.function` instance and a `tff.Type` protocol
          buffer message documenting the result structure returned by the
          concrete function.
        """
        # Save the un-flattened type spec for deserialization later.
        # Note: `training` is a Python boolean, which gets "curried", in a sense,
        # during function concretization. The resulting concrete function only has
        # parameters for `model_weights` and `x`, which are `tf.TensorSpec`
        # structures here.
        concrete_structured_fn = tf.function(
            functional_model.predict_on_batch).get_concrete_function(
                model_weights_tensor_specs,
                x_type,
                # Note: training does not appear in the resulting concrete function.
                training=training)
        output_tensor_spec_structure = tf.nest.map_structure(
            tf.TensorSpec.from_tensor, concrete_structured_fn.structured_outputs)
        result_type_spec = type_serialization.serialize_type(
            computation_types.to_type(output_tensor_spec_structure))

        @tf.function
        def flat_predict_on_batch(model_weights, x, training):
            return tf.nest.flatten(
                functional_model.predict_on_batch(model_weights, x, training))

        # `flat_predict_on_batch` is already a `tf.function`, so concretize it
        # directly, as the forward-pass path does.
        flat_concrete_fn = flat_predict_on_batch.get_concrete_function(
            model_weights_tensor_specs,
            x_type,
            # Note: training does not appear in the resulting concrete function.
            training=training)
        return flat_concrete_fn, result_type_spec

    with tf.Graph().as_default():
        predict_training, predict_training_type_spec = (
            make_concrete_flat_predict_on_batch(training=True))
    m.predict_on_batch_training = predict_training
    m.predict_on_batch_training_type_spec = tf.Variable(
        predict_training_type_spec.SerializeToString(deterministic=True),
        trainable=False)

    with tf.Graph().as_default():
        predict_inference, predict_inference_type_spec = (
            make_concrete_flat_predict_on_batch(training=False))
    m.predict_on_batch_inference = predict_inference
    m.predict_on_batch_inference_type_spec = tf.Variable(
        predict_inference_type_spec.SerializeToString(deterministic=True),
        trainable=False)

    # Serialize TFF values as string variables that contain the serialized
    # protos from the computation or the type.
    m.serialized_input_spec = tf.Variable(
        type_serialization.serialize_type(
            computation_types.to_type(
                functional_model.input_spec)).SerializeToString(
                    deterministic=True),
        trainable=False)

    # Save everything
    _save_tensorflow_module(m, path)
def _extract_intrinsics_to_top_level_lambda(comp, uri):
  r"""Extracts intrinsics in `comp` for the given `uri`.

  This transformation creates an AST such that all the called intrinsics for
  the given `uri` in the body of the `building_blocks.Block` returned by the
  top level lambda have been extracted to the top level lambda and replaced by
  selections from a reference to the constructed variable.

                       Lambda
                         |
                       Block
                      /     \
        [x=Struct, ...]       Comp
           |
           [Call,                  Call                   Call]
           /    \                 /    \                 /    \
  Intrinsic      Comp    Intrinsic      Comp    Intrinsic      Comp

  The order of the extracted called intrinsics matches the order of `uri`.

  Note: if this function is passed an AST which contains nested called
  intrinsics, it will fail, as it will mutate the subcomputation containing
  the lower-level called intrinsics on the way back up the tree.

  Args:
    comp: The `building_blocks.Lambda` to transform. The names of lambda
      parameters and block variables in `comp` must be unique.
    uri: A list of intrinsic URIs (strings).

  Returns:
    A 2-tuple `(new_comp, modified)`. If `comp` contains no called intrinsics
    for `uri`, this is `(comp, False)`. Otherwise `new_comp` is the
    transformed computation and `modified` is `True`.

  Raises:
    TypeError: If `comp` is not a `building_blocks.Lambda` or `uri` is not a
      list of strings.
    ValueError: If all the intrinsics for the given `uri` in `comp` are not
      exclusively bound by `comp`.
  """
  py_typecheck.check_type(comp, building_blocks.Lambda)
  py_typecheck.check_type(uri, list)
  for x in uri:
    py_typecheck.check_type(x, str)
  tree_analysis.check_has_unique_names(comp)

  name_generator = building_block_factory.unique_name_generator(comp)

  intrinsics = _get_called_intrinsics(comp, uri)
  # Nothing to extract: report "not modified" instead of failing on
  # `intrinsics[0]` below.
  if not intrinsics:
    return comp, False
  for intrinsic in intrinsics:
    if not tree_analysis.contains_no_unbound_references(
        intrinsic, comp.parameter_name):
      raise ValueError(
          'Expected a computation which binds all the references in all the '
          'intrinsic with the uri: {}.'.format(uri))

  if len(intrinsics) > 1:
    # Rank each URI by its first position in `uri`. `sorted` is stable, so
    # intrinsics sharing a URI keep their discovery order.
    order = {}
    for index, element in enumerate(uri):
      order.setdefault(element, index)
    intrinsics = sorted(intrinsics, key=lambda x: order[x.function.uri])
    extracted_comp = building_blocks.Struct(intrinsics)
  else:
    extracted_comp = intrinsics[0]

  ref_name = next(name_generator)
  ref_type = computation_types.to_type(extracted_comp.type_signature)
  ref = building_blocks.Reference(ref_name, ref_type)

  def _transform(comp):
    if not building_block_analysis.is_called_intrinsic(comp, uri):
      return comp, False
    if len(intrinsics) > 1:
      # NOTE(review): relies on `comp` comparing equal to its entry in
      # `intrinsics` (see the nested-intrinsics note in the docstring).
      index = intrinsics.index(comp)
      return building_blocks.Selection(ref, index=index), True
    return ref, True

  comp, _ = transformation_utils.transform_postorder(comp, _transform)
  comp = _insert_comp_in_top_level_lambda(
      comp, name=ref.name, comp_to_insert=extracted_comp)
  return comp, True
def create_whimsy_computation_tensorflow_identity(arg_type=tf.float32):
  """Builds a TensorFlow identity computation over `arg_type`.

  Args:
    arg_type: Anything accepted by `computation_types.to_type`. Defaults to
      `tf.float32`, which yields a computation of type `(float32 -> float32)`.

  Returns:
    The `(computation, type_signature)` pair produced by
    `tensorflow_computation_factory.create_identity` for `arg_type`.
  """
  identity_type = computation_types.to_type(arg_type)
  identity_comp, identity_signature = (
      tensorflow_computation_factory.create_identity(identity_type))
  return identity_comp, identity_signature
def _group_by_intrinsics_in_top_level_lambda(comp):
  """Groups the intrinsics in the first block local in the result of `comp`.

  This transformation creates an AST by replacing the tuple of called
  intrinsics found as the first local in the `building_blocks.Block` returned
  by the top level lambda with two new computations.

  The first computation is a tuple with one element per distinct URI, in
  order of first occurrence. That element is either the single called
  intrinsic with that URI, or a tuple of all the called intrinsics sharing
  that URI.

  The second computation is a tuple of selections from the first computation.
  It reproduces the original tuple of called intrinsics in its original order.

  It is necessary to group intrinsics before it is possible to merge them.

  Args:
    comp: The `building_blocks.Lambda` to transform. Its result must be a
      `building_blocks.Block` and the names in `comp` must be unique.

  Returns:
    A 2-tuple `(new_comp, modified)`. If no URI occurs more than once in the
    first local, this is `(comp, False)`. Otherwise `modified` is `True` and
    `new_comp` is a new `building_blocks.Lambda` returning a
    `building_blocks.Block` with these locals:
      - first local: the grouped tuple described above;
      - second local: the original name, rebound to the tuple of selections.

  Raises:
    TypeError: If `comp` is not a `building_blocks.Lambda`, its result is not
      a `building_blocks.Block`, or the first local of that block is not a
      `building_blocks.Struct`.
    ValueError: If the first local contains an element that is not a called
      intrinsic.
  """
  py_typecheck.check_type(comp, building_blocks.Lambda)
  py_typecheck.check_type(comp.result, building_blocks.Block)
  tree_analysis.check_has_unique_names(comp)

  name_generator = building_block_factory.unique_name_generator(comp)

  name, first_local = comp.result.locals[0]
  py_typecheck.check_type(first_local, building_blocks.Struct)
  for element in first_local:
    if not building_block_analysis.is_called_intrinsic(element):
      raise ValueError(
          'Expected all the elements of the `building_blocks.Struct` to be '
          'called intrinsics, but found: \n{}'.format(element))

  # Describe how to pack the intrinsics into groups by URI, and how to unpack
  # them again:
  #   - packed_groups maps each URI, in order of first occurrence, to the list
  #     of called intrinsics with that URI.
  #   - group_index_by_uri maps each URI to the position of its group.
  #   - packed_indexes[i] is `(group index, index within group)` for the i-th
  #     called intrinsic of `first_local`, i.e. an implicit mapping from
  #     unpacked index to packed indexes.
  # Dict lookups keep this linear in the number of intrinsics.
  packed_groups = collections.OrderedDict()
  group_index_by_uri = {}
  packed_indexes = []
  for called_intrinsic in first_local:
    uri = called_intrinsic.function.uri
    if uri not in group_index_by_uri:
      group_index_by_uri[uri] = len(packed_groups)
      packed_groups[uri] = []
    packed_group = packed_groups[uri]
    packed_group.append(called_intrinsic)
    packed_indexes.append((group_index_by_uri[uri], len(packed_group) - 1))

  # If there are no duplicates, return early.
  if len(packed_groups) == len(first_local):
    return comp, False

  groups = list(packed_groups.values())
  # A URI that occurs only once is packed as the bare intrinsic, not a 1-tuple.
  packed_comp = building_blocks.Struct([
      building_blocks.Struct(group) if len(group) > 1 else group[0]
      for group in groups
  ])
  packed_ref_name = next(name_generator)
  packed_ref_type = computation_types.to_type(packed_comp.type_signature)
  packed_ref = building_blocks.Reference(packed_ref_name, packed_ref_type)

  unpacked_elements = []
  for group_index, intrinsic_index in packed_indexes:
    sel = building_blocks.Selection(packed_ref, index=group_index)
    if len(groups[group_index]) > 1:
      sel = building_blocks.Selection(sel, index=intrinsic_index)
    unpacked_elements.append(sel)
  unpacked_comp = building_blocks.Struct(unpacked_elements)

  # Copy the locals so the input `comp` is never mutated in place.
  variables = list(comp.result.locals)
  variables[0] = (name, unpacked_comp)
  variables.insert(0, (packed_ref_name, packed_comp))
  block = building_blocks.Block(variables, comp.result.result)
  fn = building_blocks.Lambda(comp.parameter_name, comp.parameter_type, block)
  return fn, True
def test_init_does_not_raise_type_error_with_unknown_dimensions(self):
  """Checks `MapReduceForm` accepts types assignable to, not equal to, its own.

  The server state type has an unknown dimension (`shape=[None]`). Both
  `initialize` and `update` return values with concrete shapes, so their
  result types are assignable to `server_state_type` but not equal to it.
  """
  # Server state whose single dimension is unknown.
  server_state_type = computation_types.TensorType(shape=[None], dtype=tf.int32)

  @tensorflow_computation.tf_computation
  def initialize():
    # Return a value of a type assignable to, but not equal to
    # `server_state_type`
    return tf.constant([1, 2, 3])

  @tensorflow_computation.tf_computation(server_state_type)
  def prepare(server_state):
    del server_state  # Unused
    return tf.constant(1.0)

  @tensorflow_computation.tf_computation(
      computation_types.SequenceType(tf.float32), tf.float32)
  def work(client_data, client_input):
    del client_data  # Unused
    del client_input  # Unused
    return True, [], [], []

  @tensorflow_computation.tf_computation
  def zero():
    # Empty string vector; `accumulate`, `merge` and `report` take a
    # `[None]`-shaped string accumulator.
    return tf.constant([], dtype=tf.string)

  @tensorflow_computation.tf_computation(
      computation_types.TensorType(shape=[None], dtype=tf.string), tf.bool)
  def accumulate(accumulator, client_update):
    del accumulator  # Unused
    del client_update  # Unused
    return tf.constant(['abc'])

  @tensorflow_computation.tf_computation(
      computation_types.TensorType(shape=[None], dtype=tf.string),
      computation_types.TensorType(shape=[None], dtype=tf.string))
  def merge(accumulator1, accumulator2):
    del accumulator1  # Unused
    del accumulator2  # Unused
    return tf.constant(['abc'])

  @tensorflow_computation.tf_computation(
      computation_types.TensorType(shape=[None], dtype=tf.string))
  def report(accumulator):
    del accumulator  # Unused
    return tf.constant(1.0)

  # One no-argument computation returning an empty structure, reused for the
  # `bitwidth`, `max_input` and `modulus` slots of the form.
  unit_comp = tensorflow_computation.tf_computation(lambda: [])
  bitwidth = unit_comp
  max_input = unit_comp
  modulus = unit_comp
  unit_type = computation_types.to_type([])

  @tensorflow_computation.tf_computation(
      server_state_type, (tf.float32, unit_type, unit_type, unit_type))
  def update(server_state, global_update):
    del server_state  # Unused
    del global_update  # Unused
    # Return a new server state value whose type is assignable but not equal
    # to `server_state_type`, and which is different from the type returned
    # by `initialize`.
    return tf.constant([1]), []

  # The only expectation is that construction succeeds; a TypeError here means
  # assignable-but-unequal types were wrongly rejected.
  try:
    forms.MapReduceForm(initialize, prepare, work, zero, accumulate, merge,
                        report, bitwidth, max_input, modulus, update)
  except TypeError:
    self.fail('Raised TypeError unexpectedly.')
def pack_args_into_struct(
    args: Sequence[Any],
    kwargs: Mapping[str, Any],
    type_spec=None,
    context: Optional[context_base.Context] = None) -> structure.Struct:
  """Packs positional and keyword arguments into a `Struct`.

  If `type_spec` is not None, it must be a `StructType` or something that's
  convertible to it by `computation_types.to_type()`. The assignment of
  arguments to fields of the struct follows the same rule as during function
  calls.

  If `type_spec` is None, the positional arguments precede any of the keyword
  arguments, and the ordering of the keyword arguments matches the ordering
  in which they appear in kwargs. If the latter is an OrderedDict, the
  ordering will be preserved. On the other hand, if the latter is an ordinary
  unordered dict, the ordering is arbitrary.

  Args:
    args: Positional arguments.
    kwargs: Keyword arguments.
    type_spec: The optional type specification (either an instance of
      `computation_types.StructType` or something convertible to it), or None
      if there's no type. Used to drive the arrangements of args into fields
      of the constructed struct, as noted in the description. An empty struct
      type is a real type specification, not the same as None.
    context: The optional context (an instance of `context_base.Context`) in
      which the arguments are being packed. Required if and only if the
      `type_spec` is not `None`.

  Returns:
    A `structure.Struct` containing all the arguments. When `type_spec` is
    given, each element is the result of `context.ingest` for its field type.

  Raises:
    TypeError: If `type_spec` is not a `StructType` shaped like an argument
      struct, or `context` is not a `context_base.Context` while `type_spec`
      is given. Also raised if an argument is given both positionally and by
      keyword, if a field has no matching argument, or if some positional or
      keyword arguments are left unused.
  """
  # Test for `None` explicitly. A `StructType` with no elements has length 0
  # and is therefore falsy, but it is still a type the arguments must match.
  if type_spec is None:
    return structure.Struct([(None, arg) for arg in args] +
                            list(kwargs.items()))

  type_spec = computation_types.to_type(type_spec)
  py_typecheck.check_type(type_spec, computation_types.StructType)
  py_typecheck.check_type(context, context_base.Context)
  context = typing.cast(context_base.Context, context)
  if not is_argument_struct(type_spec):  # pylint: disable=attribute-error
    raise TypeError(
        'Parameter type {} does not have a structure of an argument struct, '
        'and cannot be populated from multiple positional and keyword '
        'arguments'.format(type_spec))

  result_elements = []
  positions_used = set()
  keywords_used = set()
  for index, (name, elem_type) in enumerate(structure.to_elements(type_spec)):
    if index < len(args):
      # This argument is present in `args`.
      if name is not None and name in kwargs:
        raise TypeError('Argument `{}` specified twice.'.format(name))
      arg_value = args[index]
      result_elements.append((name, context.ingest(arg_value, elem_type)))
      positions_used.add(index)
    elif name is not None and name in kwargs:
      # This argument is present in `kwargs`.
      arg_value = kwargs[name]
      result_elements.append((name, context.ingest(arg_value, elem_type)))
      keywords_used.add(name)
    elif name:
      raise TypeError(f'Missing argument `{name}` of type {elem_type}.')
    else:
      raise TypeError(
          f'Missing argument of type {elem_type} at position {index}.')

  positions_missing = set(range(len(args))).difference(positions_used)
  if positions_missing:
    raise TypeError(f'Positional arguments at {positions_missing} not used.')
  keywords_missing = set(kwargs.keys()).difference(keywords_used)
  if keywords_missing:
    raise TypeError(f'Keyword arguments at {keywords_missing} not used.')
  return structure.Struct(result_elements)
from tensorflow_federated.python.core.impl.tensorflow_context import tensorflow_computation from tensorflow_federated.python.core.impl.types import computation_types from tensorflow_federated.python.core.impl.types import placements from tensorflow_federated.python.core.templates import measured_process from tensorflow_federated.python.learning import keras_utils from tensorflow_federated.python.learning import model_examples from tensorflow_federated.python.learning import model_utils from tensorflow_federated.python.learning.optimizers import sgdm from tensorflow_federated.python.learning.templates import client_works from tensorflow_federated.python.learning.templates import composers from tensorflow_federated.python.learning.templates import distributors from tensorflow_federated.python.learning.templates import finalizers from tensorflow_federated.python.learning.templates import learning_process FLOAT_TYPE = computation_types.TensorType(tf.float32) MODEL_WEIGHTS_TYPE = computation_types.to_type( model_utils.ModelWeights(FLOAT_TYPE, ())) CLIENTS_SEQUENCE_FLOAT_TYPE = computation_types.at_clients( computation_types.SequenceType(FLOAT_TYPE)) def empty_at_server(): return intrinsics.federated_value((), placements.SERVER) @federated_computation.federated_computation() def empty_init_fn(): return empty_at_server() @tensorflow_computation.tf_computation() def test_init_model_weights_fn():
class SecureModularSumFactoryComputationTest(tf.test.TestCase,
                                             parameterized.TestCase):
  """Tests `SecureModularSumFactory` validation and created process types."""

  @parameterized.named_parameters(
      ('scalar_non_symmetric_int32', 8, tf.int32, False),
      ('scalar_non_symmetric_int64', 8, tf.int64, False),
      ('struct_non_symmetric', 8, _test_struct_type(tf.int32), False),
      ('scalar_symmetric_int32', 8, tf.int32, True),
      ('scalar_symmetric_int64', 8, tf.int64, True),
      ('struct_symmetric', 8, _test_struct_type(tf.int32), True),
      ('numpy_modulus_non_symmetric', np.int32(8), tf.int32, False),
      ('numpy_modulus_symmetric', np.int32(8), tf.int32, True),
  )
  def test_type_properties(self, modulus, value_type, symmetric_range):
    """Checks process type signatures and that `next` aggregates securely."""
    factory_ = secure.SecureModularSumFactory(
        modulus=modulus, symmetric_range=symmetric_range)
    self.assertIsInstance(factory_, factory.UnweightedAggregationFactory)
    value_type = computation_types.to_type(value_type)
    process = factory_.create(value_type)
    self.assertIsInstance(process, aggregation_process.AggregationProcess)

    # Both the state and the measurements are an empty structure at SERVER.
    expected_state_type = computation_types.at_server(
        computation_types.to_type(()))
    expected_measurements_type = expected_state_type

    expected_initialize_type = computation_types.FunctionType(
        parameter=None, result=expected_state_type)
    self.assertTrue(
        process.initialize.type_signature.is_equivalent_to(
            expected_initialize_type))

    expected_next_type = computation_types.FunctionType(
        parameter=collections.OrderedDict(
            state=expected_state_type,
            value=computation_types.at_clients(value_type)),
        result=measured_process.MeasuredProcessOutput(
            state=expected_state_type,
            result=computation_types.at_server(value_type),
            measurements=expected_measurements_type))
    self.assertTrue(
        process.next.type_signature.is_equivalent_to(expected_next_type))

    # Catch `Exception` rather than using a bare `except`, so that
    # `KeyboardInterrupt` / `SystemExit` are not turned into test failures.
    try:
      static_assert.assert_not_contains_unsecure_aggregation(process.next)
    except Exception:  # pylint: disable=broad-except
      self.fail('Factory returned an AggregationProcess containing '
                'non-secure aggregation.')

  def test_float_modulus_raises(self):
    """Python and NumPy float moduli are rejected with TypeError."""
    with self.assertRaises(TypeError):
      secure.SecureModularSumFactory(modulus=8.0)
    with self.assertRaises(TypeError):
      secure.SecureModularSumFactory(modulus=np.float32(8.0))

  def test_modulus_not_positive_raises(self):
    """Zero and negative moduli are rejected with ValueError."""
    with self.assertRaises(ValueError):
      secure.SecureModularSumFactory(modulus=0)
    with self.assertRaises(ValueError):
      secure.SecureModularSumFactory(modulus=-1)

  def test_symmetric_range_not_bool_raises(self):
    """A non-bool `symmetric_range` is rejected with TypeError."""
    with self.assertRaises(TypeError):
      secure.SecureModularSumFactory(modulus=8, symmetric_range='True')

  @parameterized.named_parameters(
      ('float_type', computation_types.TensorType(tf.float32)),
      ('mixed_type', computation_types.to_type([tf.float32, tf.int32])),
      ('federated_type',
       computation_types.FederatedType(tf.int32, placements.SERVER)),
      ('function_type', computation_types.FunctionType(None, ())),
      ('sequence_type', computation_types.SequenceType(tf.float32)))
  def test_incorrect_value_type_raises(self, bad_value_type):
    """`create` rejects value types listed above with TypeError."""
    with self.assertRaises(TypeError):
      secure.SecureModularSumFactory(8).create(bad_value_type)