def assert_less_equal_max_and_add(summation_and_max_input, summand):
  """Adds `summand` into a running summation after bounds-checking it.

  Args:
    summation_and_max_input: A pair `(summation, max_input)`. When the
      enclosing `max_input_type` is a struct, `max_input` is used per-leaf;
      otherwise the single `max_input` is broadcast to the structure of
      `summand`.
    summand: The client value to add into `summation`.

  Returns:
    A pair of the pointwise `tf.add` of `summation` and `summand` (created
    under control dependencies on the bound checks), and the `max_input`
    exactly as it was received.
  """
  summation, original_max_input = summation_and_max_input

  # Broadcast a single limit across every leaf of `summand` unless the limit
  # already carries its own per-leaf structure.
  if max_input_type.is_struct():
    max_input = original_max_input
  else:
    max_input = structure.map_structure(lambda *_: original_max_input,
                                        summand)

  def _check_leaf(value, limit):
    # Both sides are cast to int64 before the elementwise comparison.
    within_limit = tf.reduce_all(
        tf.less_equal(tf.cast(value, tf.int64), tf.cast(limit, tf.int64)))
    return tf.Assert(within_limit, [
        'client value larger than maximum specified for secure sum', value,
        'not less than or equal to', limit
    ])

  checks = structure.flatten(
      structure.map_structure(_check_leaf, summand, max_input))
  with tf.control_dependencies(checks):
    new_summation = structure.map_structure(tf.add, summation, summand)
  return new_summation, original_max_input
def zeros_fn():
  """Builds the initial accumulator value for `member_type`.

  Every leaf type is first validated with `_validate_dtype_is_numeric`, then
  materialized with `tf.fill` at the leaf's shape using
  `initial_value_fn(leaf_type)` as the fill value.

  Returns:
    A tensor, or a structure of tensors mirroring `member_type`.
  """

  def _fill(tensor_type):
    return tf.fill(tensor_type.shape, value=initial_value_fn(tensor_type))

  if not member_type.is_struct():
    _validate_dtype_is_numeric(member_type.dtype)
    return _fill(member_type)

  # Validate every leaf before any tensor is built.
  structure.map_structure(
      lambda leaf_type: _validate_dtype_is_numeric(leaf_type.dtype),
      member_type)
  return structure.map_structure(_fill, member_type)
def test_map_structure_tensors(self):
  """`map_structure` applies the function directly to bare tensors."""
  # Scalar integer tensors.
  self.assertAllEqual(
      structure.map_structure(tf.add, tf.constant(1), tf.constant(2)), 3)

  # String tensors: `tf.add` joins the elements pairwise.
  left_chars = tf.strings.bytes_split('abc')
  right_chars = tf.strings.bytes_split('xyz')
  self.assertAllEqual(
      structure.map_structure(tf.add, left_chars, right_chars),
      ['ax', 'by', 'cz'])
async def create_struct(self, elements):
  """Creates a struct in the target executor and records the call in the trace.

  Args:
    elements: A structure whose leaves expose `.value` and `.index`
      (presumably `TracingExecutorValue`s produced by this executor).

  Returns:
    A new `TracingExecutorValue` wrapping the target executor's struct.
  """
  unwrapped = structure.map_structure(lambda element: element.value, elements)
  target_val = await self._target.create_struct(unwrapped)
  wrapped_val = TracingExecutorValue(self, self._get_new_value_index(),
                                     target_val)
  element_indices = structure.map_structure(lambda element: element.index,
                                            elements)
  self._trace.append(('create_struct', element_indices, wrapped_val.index))
  return wrapped_val
def test_map_structure_fails_different_structures(self):
  """`map_structure` raises `TypeError` when the two structures differ."""
  mismatched_pairs = (
      # The second structure is missing element `c`.
      (structure.Struct.named(a=10, c=20), structure.Struct.named(a=30)),
      # The second structure has an extra element `c`.
      (structure.Struct.named(a=10),
       structure.Struct.named(a=30, c=tf.strings.bytes_split('abc'))),
  )
  for first, second in mismatched_pairs:
    with self.assertRaises(TypeError):
      structure.map_structure(tf.add, first, second)
def test_map_structure_tensor_fails(self):
  """`map_structure` raises `TypeError` when pairing a `Struct` with a tensor."""
  struct_and_tensor_pairs = (
      (structure.Struct.named(a=10, c=20), tf.constant(2)),
      (structure.Struct.named(a='abc', c='xyz'),
       tf.strings.bytes_split('abc')),
  )
  for struct_value, tensor_value in struct_and_tensor_pairs:
    with self.assertRaises(TypeError):
      structure.map_structure(tf.add, struct_value, tensor_value)
class CreateUnaryOperatorTest(parameterized.TestCase, tf.test.TestCase):
  """Tests for `tensorflow_computation_factory.create_unary_operator`."""

  # Each case is (name, operator, operand_type, operand, expected_result).
  @parameterized.named_parameters(
      ('abs_int', tf.math.abs, _TensorType(tf.int32), [-1], 1),
      ('abs_float', tf.math.abs, _TensorType(tf.float32), [-1.0], 1.0),
      ('abs_unnamed_tuple',
       lambda x: structure.map_structure(tf.math.abs, x),
       _StructType([_TensorType(tf.int32, [2]),
                    _TensorType(tf.float32, [2])]),
       [[-1, -2], [-3.0, -4.0]],
       structure.Struct([(None, [1, 2]), (None, [3.0, 4.0])])),
      ('abs_named_tuple',
       lambda x: structure.map_structure(tf.math.abs, x),
       _StructType([('a', _TensorType(tf.int32, [2])),
                    ('b', _TensorType(tf.float32, [2]))]),
       [[-1, -2], [-3.0, -4.0]],
       structure.Struct([('a', [1, 2]), ('b', [3.0, 4.0])])),
      ('reduce_sum_int', tf.math.reduce_sum, _TensorType(tf.int32, [2]),
       [2, 2], 4),
      ('reduce_sum_float', tf.math.reduce_sum, _TensorType(tf.float32, [2]),
       [2.0, 2.5], 4.5),
      ('log_inf', tf.math.log, _TensorType(tf.float32), [0.0], -np.inf),
  )
  # pyformat: enable
  def test_returns_computation(self, operator, operand_type, operand,
                               expected_result):
    """Checks the built computation's parameter type and its evaluated output."""
    proto, _ = tensorflow_computation_factory.create_unary_operator(
        operator, operand_type)

    self.assertIsInstance(proto, pb.Computation)
    actual_type = type_serialization.deserialize_type(proto.type)
    self.assertIsInstance(actual_type, computation_types.FunctionType)
    # Note: It is only useful to test the parameter type; the result type
    # depends on the `operator` used, not the implementation of
    # `create_unary_operator`.
    expected_parameter_type = operand_type
    self.assertEqual(actual_type.parameter, expected_parameter_type)
    actual_result = test_utils.run_tensorflow(proto, operand)
    self.assertAllEqual(actual_result, expected_result)

  # Invalid operators or operand types; each case is
  # (name, operator, type_signature).
  @parameterized.named_parameters(
      ('non_callable_operator', 1, _TensorType(tf.int32)),
      ('none_type', tf.math.add, None),
      ('federated_type', tf.math.add, computation_types.at_server(tf.int32)),
      ('sequence_type', tf.math.add, computation_types.SequenceType(tf.int32)),
  )
  def test_raises_type_error(self, operator, type_signature):
    """Each invalid (operator, type) combination is rejected with `TypeError`."""
    with self.assertRaises(TypeError):
      tensorflow_computation_factory.create_unary_operator(
          operator, type_signature)
def test_input_spec_struct(self):
  """A `StructType` input spec is exposed as a `Struct` of `tf.TensorSpec`s."""
  keras_model = model_examples.build_linear_regression_keras_functional_model(
      feature_dims=1)
  column_spec = dict(shape=[None, 1], dtype=tf.float32)
  input_spec = computation_types.StructType(
      collections.OrderedDict(
          x=tf.TensorSpec(**column_spec), y=tf.TensorSpec(**column_spec)))

  tff_model = keras_utils.from_keras_model(
      keras_model=keras_model,
      input_spec=input_spec,
      loss=tf.keras.losses.MeanSquaredError())

  self.assertIsInstance(tff_model, model_utils.EnhancedModel)
  self.assertIsInstance(tff_model.input_spec, structure.Struct)

  def _assert_tensor_spec(leaf):
    self.assertIsInstance(leaf, tf.TensorSpec)

  structure.map_structure(_assert_tensor_spec, tff_model.input_spec)
def test_map_structure(self):
  """`map_structure` adds two nested named `Struct`s leaf by leaf."""
  named = structure.Struct.named

  def _nested(a, p, y, q, r, c):
    # Every instance shares the same nesting; only the leaf values differ.
    return named(
        a=a, b=named(x=named(p=p), y=y, z=named(q=q, r=r)), c=c)

  lhs = _nested(a=10, p=40, y=30, q=50, r=60, c=20)
  rhs = _nested(a=1, p=4, y=3, q=5, r=6, c=2)
  expected = _nested(a=11, p=44, y=33, q=55, r=66, c=22)

  self.assertEqual(
      structure.map_structure(lambda u, v: u + v, lhs, rhs), expected)
def identity(source):
  """Applies `tf.identity` to every leaf of `source`.

  This has the same behavior as `tf.identity`, but it also accepts plain
  tensors, variables, and nested structures of either. It is typically
  combined with `tf.control_dependencies` in non-eager TensorFlow.

  Args:
    source: A nested structure of tensors or variables, held either in
      `tf.nest`-compatible containers or in `structure.Struct` instances. Any
      leaf with a `read_value` method (e.g. a `tf.Variable`) is read with it
      before the identity is applied.

  Returns:
    A structure matching `source` whose leaves are `tf.identity` of the
    (read) input leaves.

  Raises:
    TypeError: If a leaf is not a tensor.
  """

  def _identity_leaf(leaf):
    if not tf.is_tensor(leaf):
      raise TypeError('Expected a tensor, found {}.'.format(
          py_typecheck.type_string(type(leaf))))
    value = leaf.read_value() if hasattr(leaf, 'read_value') else leaf
    return tf.identity(value)

  # TODO(b/113112108): Extend this to containers of mixed types.
  if isinstance(source, structure.Struct):
    mapper = structure.map_structure
  else:
    mapper = tf.nest.map_structure
  return mapper(_identity_leaf, source)
def serialize_jax_computation(traced_fn, arg_fn, parameter_type,
                              context_stack):
  """Traces a Python function containing JAX code into a TFF computation.

  Args:
    traced_fn: The Python function with JAX code. It is traced by JAX and
      serialized as a TFF computation containing XLA code.
    arg_fn: An unpacking function. It takes a TFF argument and returns the
      `(args, kwargs)` pair used to invoke `traced_fn` (e.g., one built by
      `function_utils.create_argument_unpacking_fn`).
    parameter_type: A `computation_types.Type` (or something convertible to
      one) for the computation parameter. Pass `None` if the function takes
      no parameters.
    context_stack: The context stack to use during serialization.

  Returns:
    A `pb.Computation` holding the constructed computation.

  Raises:
    TypeError: If the arguments are of the wrong types.
  """
  py_typecheck.check_callable(traced_fn)
  py_typecheck.check_callable(arg_fn)
  py_typecheck.check_type(context_stack, context_stack_base.ContextStack)

  packed_arg = None
  if parameter_type is not None:
    parameter_type = computation_types.to_type(parameter_type)
    packed_arg = _tff_type_to_xla_serializer_arg(parameter_type)

  args, kwargs = arg_fn(packed_arg)

  # The placeholder parameters passed through args/kwargs may be reordered in
  # the generated XLA. The JAX serializer flattens `(args, kwargs)` with
  # `jax.tree_util.tree_flatten`, so that same flattening is applied here.
  # This records the ordering in the parameter binding. Results need no such
  # treatment: when there are several, they are always returned as a tuple.
  flat_leaves, _ = jax.tree_util.tree_flatten((args, kwargs))
  tensor_indexes = list(
      np.argsort([leaf.tensor_index for leaf in flat_leaves]))

  with context_stack.install(
      jax_computation_context.JaxComputationContext()):
    tracer_callable = jax.xla_computation(
        traced_fn, tuple_args=True, return_shape=True)
    compiled_xla, returned_shape = tracer_callable(*args, **kwargs)

  if isinstance(returned_shape, jax.ShapeDtypeStruct):
    returned_type_spec = _jax_shape_dtype_struct_to_tff_tensor(returned_shape)
  else:
    returned_struct = structure.from_container(returned_shape, recursive=True)
    returned_type_spec = computation_types.to_type(
        structure.map_structure(_jax_shape_dtype_struct_to_tff_tensor,
                                returned_struct))

  computation_type = computation_types.FunctionType(parameter_type,
                                                    returned_type_spec)
  return xla_serialization.create_xla_tff_computation(
      compiled_xla, tensor_indexes, computation_type)
def _get_accumulator_type(member_type):
  """Constructs a `tff.Type` for the accumulator in sample aggregation.

  Args:
    member_type: A `tff.Type` for the member components of the federated
      type.

  Returns:
    A `tff.StructType` with two parallel elements, `accumulators` and
    `rands`: the i-th entry of one corresponds to the i-th entry of the
    other. Together they are used to sample from the accumulators with equal
    probability.
  """

  def _with_sample_dim(tensor_type):
    # Prepend an unknown-size leading dimension that indexes the samples.
    return computation_types.TensorType(tensor_type.dtype,
                                        [None] + tensor_type.shape.dims)

  # TODO(b/121288403): Special-casing anonymous tuple shouldn't be needed.
  if member_type.is_struct():
    sampled = structure.map_structure(_with_sample_dim, member_type)
    accumulators_type = computation_types.StructType(
        structure.to_odict(sampled, True))
  else:
    accumulators_type = _with_sample_dim(member_type)

  rands_type = computation_types.TensorType(tf.float32, shape=[None])
  return computation_types.StructType(
      collections.OrderedDict(
          accumulators=accumulators_type, rands=rands_type))
def _remove_batch_dim(
    type_spec: computation_types.Type) -> computation_types.Type:
  """Strips the leading (batch) dimension from every `tff.TensorType` leaf.

  Args:
    type_spec: A `tff.Type` whose leaves are `tff.TensorType`s. The first
      dimension of each leaf is taken to be the batch dimension.

  Returns:
    A `tff.Type` with the same structure as `type_spec`. Every leaf
    `tff.TensorType` has its first dimension removed.

  Raises:
    TypeError: If a leaf is not a `tff.TensorType`.
    ValueError: If a leaf has unknown rank or rank 0.
  """

  def _drop_leading_dim(tensor_type):
    """Returns `tensor_type` without its first dimension."""
    py_typecheck.check_type(tensor_type, computation_types.TensorType)
    rank = tensor_type.shape.rank
    if rank is None or rank < 1:
      raise ValueError('Provided shape must have rank 1 or higher.')
    return computation_types.TensorType(
        shape=tensor_type.shape[1:], dtype=tensor_type.dtype)

  return structure.map_structure(_drop_leading_dim, type_spec)
def assign(target, source):
  """Creates a single op that assigns `source` into `target`.

  This has the same behavior as `tf.Variable.assign`, but it also accepts
  plain variables and nested structures of variables.

  Args:
    target: A nested structure of variables, held either in
      `tf.nest`-compatible containers or in `structure.Struct` instances.
    source: A nested structure of tensors whose structure matches `target`.

  Returns:
    A single op, built with `tf.group`, that runs every per-variable
    assignment.

  Raises:
    TypeError: If types mismatch.
  """

  def _assign_leaf(variable, value):
    return variable.assign(value)

  # TODO(b/113112108): Extend this to containers of mixed types.
  if isinstance(target, structure.Struct):
    assignments = structure.flatten(
        structure.map_structure(_assign_leaf, target, source))
  else:
    assignments = tf.nest.flatten(
        tf.nest.map_structure(_assign_leaf, target, source))
  return tf.group(*assignments)
def _compute_summation_type_for_bitwidth(bitwidth, type_spec):
  """Creates a `tff.Type` with dtype based on bitwidth.

  Args:
    bitwidth: A bitwidth for a tensor `type_spec`, or a structure of
      bitwidths matching a struct `type_spec`.
    type_spec: A `tff.TensorType` or `tff.StructType`.

  Returns:
    A type with the same shape(s) as `type_spec`. Leaves use a 32-bit integer
    dtype when their bitwidth is below 32, and a 64-bit one otherwise. The
    dtype is unsigned when the original leaf dtype is unsigned.

  Raises:
    ValueError: If a bitwidth lies outside `[1, MAXIMUM_SUPPORTED_BITWIDTH]`.
    TypeError: If `type_spec` is neither a tensor nor a struct type.
  """

  def _summation_tensor_type(bits, tensor_type):
    if bits < 1 or bits > MAXIMUM_SUPPORTED_BITWIDTH:
      raise ValueError(
          'Encountered an bitwidth that cannot be handled: {b}. '
          'Extended bitwidth must be between [1,{m}].'
          '\nRequested: {r}'.format(
              b=bits, r=bitwidth, m=MAXIMUM_SUPPORTED_BITWIDTH))
    if bits < 32:
      signed_dtype, unsigned_dtype = tf.int32, tf.uint32
    else:
      signed_dtype, unsigned_dtype = tf.int64, tf.uint64
    dtype = unsigned_dtype if tensor_type.dtype.is_unsigned else signed_dtype
    return computation_types.TensorType(shape=tensor_type.shape, dtype=dtype)

  if type_spec.is_tensor():
    return _summation_tensor_type(bitwidth, type_spec)
  if type_spec.is_struct():
    per_leaf = structure.map_structure(_summation_tensor_type, bitwidth,
                                       type_spec)
    return computation_types.StructType(structure.iter_elements(per_leaf))
  raise TypeError(
      'Summation types can only be created from TensorType or '
      'StructType. Received a {t}'.format(t=type_spec))
def test_map_structure(self):
  """`map_structure` adds two nested positional `Struct`s leaf by leaf."""

  def _nested(a, p, y, q, r, c):
    # Every instance shares the same nesting; only the leaf values differ.
    return structure.Struct([
        ('a', a),
        ('b', structure.Struct([
            ('x', structure.Struct([('p', p)])),
            ('y', y),
            ('z', structure.Struct([('q', q), ('r', r)])),
        ])),
        ('c', c),
    ])

  lhs = _nested(10, 40, 30, 50, 60, 20)
  rhs = _nested(1, 4, 3, 5, 6, 2)
  self.assertEqual(
      structure.map_structure(lambda u, v: u + v, lhs, rhs),
      _nested(11, 44, 33, 55, 66, 22))
def federated_secure_sum(self, value, bitwidth):
  """Implements `federated_secure_sum` as defined in `api/intrinsics.py`.

  Args:
    value: The value to be summed. It is converted to a value and passed to
      `value_utils.ensure_federated_value` with `CLIENTS` placement, then
      checked to be a structure of integers.
    bitwidth: A single bitwidth, or one bitwidth per tensor in `value`. A
      single tensor bitwidth is broadcast when `value` is a struct.

  Returns:
    A `value_impl.ValueImpl` wrapping a reference to the secure-sum
    computation.

  Raises:
    TypeError: If `bitwidth` does not match the structure of `value`.
  """
  raw_value = value_impl.to_value(value, None, self._context_stack)
  value = value_utils.ensure_federated_value(raw_value,
                                             placement_literals.CLIENTS,
                                             'value to be summed')
  type_analysis.check_is_structure_of_integers(value.type_signature)

  bitwidth_value = value_impl.to_value(bitwidth, None, self._context_stack)
  member_type = value.type_signature.member
  bitwidth_type = bitwidth_value.type_signature
  if not type_analysis.is_valid_bitwidth_type_for_value_type(
      bitwidth_type, member_type):
    raise TypeError(
        'Expected `federated_secure_sum` parameter `bitwidth` to match '
        'the structure of `value`, with one integer bitwidth per tensor in '
        '`value`. Found `value` of `{}` and `bitwidth` of `{}`.'.format(
            member_type, bitwidth_type))

  # Broadcast a single bitwidth so there is one entry per tensor in `value`.
  if bitwidth_type.is_tensor() and member_type.is_struct():
    broadcast_bitwidth = structure.map_structure(lambda _: bitwidth,
                                                 member_type)
    bitwidth_value = value_impl.to_value(broadcast_bitwidth, None,
                                         self._context_stack)

  comp = building_block_factory.create_federated_secure_sum(
      value_impl.ValueImpl.get_comp(value),
      value_impl.ValueImpl.get_comp(bitwidth_value))
  comp = self._bind_comp_as_reference(comp)
  return value_impl.ValueImpl(comp, self._context_stack)
def serialize_jax_computation(traced_fn, arg_fn, parameter_type,
                              context_stack):
  """Traces a Python function containing JAX code into a TFF computation.

  Args:
    traced_fn: The Python function with JAX code. It is traced by JAX and
      serialized as a TFF computation containing XLA code.
    arg_fn: An unpacking function. It takes a TFF argument and returns the
      `(args, kwargs)` pair used to invoke `traced_fn` (e.g., one built by
      `function_utils.create_argument_unpacking_fn`).
    parameter_type: A `computation_types.Type` (or something convertible to
      one) for the computation parameter. Pass `None` if the function takes
      no parameters.
    context_stack: The context stack to use during serialization.

  Returns:
    A `pb.Computation` holding the constructed computation.

  Raises:
    TypeError: If the arguments are of the wrong types.
  """
  py_typecheck.check_callable(traced_fn)
  py_typecheck.check_callable(arg_fn)
  py_typecheck.check_type(context_stack, context_stack_base.ContextStack)

  packed_arg = None
  if parameter_type is not None:
    parameter_type = computation_types.to_type(parameter_type)
    packed_arg = _tff_type_to_xla_serializer_arg(parameter_type)

  args, kwargs = arg_fn(packed_arg)

  def _to_py_container(arg):
    # Convert each argument to the Python container implied by its TFF type.
    return type_conversions.type_to_py_container(arg, arg.type_signature)

  args = list(map(_to_py_container, args))
  kwargs = {name: _to_py_container(arg) for name, arg in kwargs.items()}

  with context_stack.install(
      jax_computation_context.JaxComputationContext()):
    tracer_callable = jax.xla_computation(
        traced_fn, tuple_args=True, return_shape=True)
    compiled_xla, returned_shape = tracer_callable(*args, **kwargs)

  if isinstance(returned_shape, jax.ShapeDtypeStruct):
    returned_type_spec = _jax_shape_dtype_struct_to_tff_tensor(returned_shape)
  else:
    returned_struct = structure.from_container(returned_shape, recursive=True)
    returned_type_spec = computation_types.to_type(
        structure.map_structure(_jax_shape_dtype_struct_to_tff_tensor,
                                returned_struct))

  computation_type = computation_types.FunctionType(parameter_type,
                                                    returned_type_spec)
  return xla_serialization.create_xla_tff_computation(compiled_xla,
                                                      computation_type)
def _ensure_structure(int_or_structure, int_or_structure_type, possible_struct_type): if int_or_structure_type.is_struct() or not possible_struct_type.is_struct(): return int_or_structure else: # Broadcast int_or_structure to the same structure as the struct type return structure.map_structure(lambda *args: int_or_structure, possible_struct_type)
def apply_sampling(accumulators, rands):
  """Keeps the entries with the largest `rands`, up to `max_num_samples`.

  Args:
    accumulators: The candidate samples: a tensor, or a structure of tensors
      when `member_type` is a struct.
    rands: A tensor of random keys, parallel to `accumulators` along its
      first dimension.

  Returns:
    A pair `(accumulators, rands)` gathered at the selected indices.
  """
  candidate_count = tf.shape(rands)[0]
  keep_count = tf.minimum(candidate_count, max_num_samples)
  # Positions of the `keep_count` largest random keys.
  kept_indices = tf.math.top_k(rands, k=keep_count).indices

  # TODO(b/121288403): Special-casing anonymous tuple shouldn't be needed.
  if member_type.is_struct():
    kept_accumulators = structure.map_structure(
        lambda leaf: fed_gather(leaf, kept_indices), accumulators)
  else:
    kept_accumulators = fed_gather(accumulators, kept_indices)
  return kept_accumulators, fed_gather(rands, kept_indices)
def merge(a, b):
  """Merges two sample accumulators by concatenation, then re-samples.

  Args:
    a: The first accumulator.
    b: The second accumulator.

  Returns:
    A `_Samples` holding the accumulators and random keys that
    `apply_sampling` keeps after concatenation.
  """
  # TODO(b/121288403): Special-casing anonymous tuple shouldn't be needed.
  if not accumulator_type.is_struct():
    combined = fed_concat(a, b)
  else:
    combined = structure.map_structure(fed_concat, _ensure_structure(a),
                                       _ensure_structure(b))
  kept_accumulators, kept_rands = apply_sampling(combined.accumulators,
                                                 combined.rands)
  return _Samples(kept_accumulators, kept_rands)
def accumlator_type_fn():
  """Gets the type for the accumulators.

  Builds empty accumulators matching `member_type`: every tensor gets an
  extra leading dimension of size 0. They are paired with an empty float32
  `rands` tensor.

  Returns:
    A `_Samples` of empty accumulators and empty `rands`.

  Raises:
    ValueError: If `member_type` is a tensor type whose shape is falsy
      (expected to mean unknown rank, per `tf.TensorShape.__bool__`).
  """
  # TODO(b/121288403): Special-casing anonymous tuple shouldn't be needed.
  if member_type.is_struct():
    a = structure.map_structure(
        lambda v: tf.zeros([0] + v.shape.dims, v.dtype), member_type)
    return _Samples(structure.to_odict(a, True), tf.zeros([0], tf.float32))
  # Without known dimensions there is no shape to prepend the sample axis to.
  # Previously this case left `s` unassigned and crashed with
  # UnboundLocalError.
  if not member_type.shape:
    raise ValueError(
        'Cannot build sample accumulators for a tensor type without a known '
        'rank: {}'.format(member_type))
  s = [0] + member_type.shape.dims
  return _Samples(tf.zeros(s, member_type.dtype), tf.zeros([0], tf.float32))
def _validate_value_type_and_encoders(value_type, encoders, encoder_type):
  """Validates if `value_type` and `encoders` are compatible.

  Args:
    value_type: The `tff.Type` of the value to be encoded.
    encoders: Either a single encoder (an instance of `_ALLOWED_ENCODERS`) or
      a container of encoders.
    encoder_type: Passed through to `_validate_encoder` for every encoder.

  Raises:
    ValueError: If `encoders` is a single encoder but `value_type` is not a
      `tff.TensorType`.
    TypeError: If `encoders` is a container but `value_type` is not a
      `tff.StructType`.
  """
  if not isinstance(encoders, _ALLOWED_ENCODERS):
    # A container of encoders must be paired with a `tff.StructType`.
    if not isinstance(value_type, computation_types.StructType):
      raise TypeError('`value_type` is not compatible with the expected input '
                      'of the `encoders`.')
    encoder_struct = structure.from_container(encoders, recursive=True)
    structure.map_structure(
        lambda enc, leaf_type: _validate_encoder(enc, leaf_type, encoder_type),
        encoder_struct, value_type)
    return

  # A single encoder must be paired with a `tff.TensorType`.
  if not isinstance(value_type, computation_types.TensorType):
    raise ValueError(
        '`value_type` and `encoders` do not have the same structure.')
  _validate_encoder(encoders, value_type, encoder_type)
def build_zero_argument(parameter_type):
  """Builds a zero-like Python argument for `parameter_type`.

  Args:
    parameter_type: `None`, a struct type, `tffint32`, or `tffstring`.

  Returns:
    `None` for a `None` type. Struct types are mapped recursively, leaf by
    leaf. `tffint32` gives `0` and `tffstring` gives `''`.

  Raises:
    NotImplementedError: For any other type.
  """
  if parameter_type is None:
    return None
  if parameter_type.is_struct():
    return structure.map_structure(build_zero_argument, parameter_type)
  if parameter_type == tffint32:
    return 0
  if parameter_type == tffstring:
    return ''
  raise NotImplementedError(f'Unsupported type: {parameter_type}')
def to_representation_for_type(value, type_spec, backend=None):
  """Checks or converts `value` into an executor payload matching `type_spec`.

  Supported kinds of `value`:
  * Computations, either `pb.Computation` or `computation_impl.ComputationImpl`.
  * Numpy arrays and scalars, or Python scalars (converted to Numpy).
  * Nested structures of the above.

  Args:
    value: The raw representation to compare against `type_spec`, and to
      convert if needed.
    type_spec: A `tff.Type`. May be `None` for values that derive from
      `typed_object.TypedObject`.
    backend: The backend to use: an `xla_client.Client`. It is only used for
      functional types and may be `None` when unused.

  Returns:
    Either `value` itself or a converted version of it.

  Raises:
    TypeError: If `value` is not compatible with `type_spec`.
    ValueError: If the arguments are incorrect.
  """
  if backend is not None:
    py_typecheck.check_type(backend, xla_client.Client)
  if type_spec is not None:
    type_spec = computation_types.to_type(type_spec)
  type_spec = type_utils.reconcile_value_with_type_spec(value, type_spec)

  if isinstance(value, computation_base.Computation):
    # Unwrap to the underlying proto, which the next branch handles.
    proto = computation_impl.ComputationImpl.get_proto(value)
    return to_representation_for_type(proto, type_spec, backend)

  if isinstance(value, pb.Computation):
    comp_type = type_serialization.deserialize_type(value.type)
    if type_spec is not None:
      comp_type.check_equivalent_to(type_spec)
    return _ComputationCallable(value, comp_type, backend)

  if isinstance(type_spec, computation_types.StructType):
    value_struct = structure.from_container(value, recursive=True)
    return structure.map_structure(
        lambda elem, elem_type: to_representation_for_type(
            elem, elem_type, backend), value_struct, type_spec)

  if not isinstance(type_spec, computation_types.TensorType):
    raise TypeError('Unexpected type {}.'.format(type_spec))

  type_spec.shape.assert_is_fully_defined()
  type_analysis.check_type(value, type_spec)
  rank = type_spec.shape.rank
  if rank == 0:
    # Scalars become NumPy scalars of the declared dtype.
    return np.dtype(type_spec.dtype.as_numpy_dtype).type(value)
  if rank > 0:
    return np.array(value, dtype=type_spec.dtype.as_numpy_dtype)
  raise TypeError('Unsupported tensor shape {}.'.format(type_spec.shape))
def accumulate(current, value):
  """Accumulates samples through concatenation.

  `value` is appended to the current accumulators, and a fresh
  `tf.random.uniform` key is appended to `current.rands`. The result is then
  re-sampled with `apply_sampling`.

  Args:
    current: The current `_Samples` (with `accumulators` and `rands`).
    value: The new sample to add.

  Returns:
    A `_Samples` holding the re-sampled accumulators and random keys.
  """
  rands = fed_concat_expand_dims(current.rands, tf.random.uniform(shape=()))

  # TODO(b/121288403): Special-casing anonymous tuple shouldn't be needed.
  if not member_type.is_struct():
    grown = fed_concat_expand_dims(current.accumulators, value)
  else:
    grown = structure.map_structure(fed_concat_expand_dims,
                                    _ensure_structure(current.accumulators),
                                    _ensure_structure(value))
  kept_accumulators, kept_rands = apply_sampling(grown, rands)
  return _Samples(kept_accumulators, kept_rands)
def test_returns_value_with_intrinsic_def_federated_secure_sum(
    self, client_values, bitwidth, expected_result):
  """Runs the secure-sum intrinsic on client values and checks the result."""
  executor = create_test_executor()
  value_type = computation_types.at_clients(
      type_conversions.infer_type(client_values[0]))
  bitwidth_type = type_conversions.infer_type(bitwidth)
  comp, comp_type = create_intrinsic_def_federated_secure_sum(
      value_type.member, bitwidth_type)

  def _create(raw, type_spec):
    return self.run_sync(executor.create_value(raw, type_spec))

  comp = _create(comp, comp_type)
  value_arg = _create(client_values, value_type)
  bitwidth_arg = _create(bitwidth, bitwidth_type)
  args = self.run_sync(executor.create_struct([value_arg, bitwidth_arg]))
  result = self.run_sync(executor.create_call(comp, args))

  self.assertIsInstance(result, executor_value_base.ExecutorValue)
  self.assert_types_identical(result.type_signature, comp_type.result)
  actual_result = self.run_sync(result.compute())
  if not isinstance(expected_result, structure.Struct):
    self.assertEqual(actual_result, expected_result)
  else:
    structure.map_structure(self.assertAllEqual, actual_result,
                            expected_result)
def test_with_tuple_of_unnamed_elements(self):
  """Unnamed structs get `<i,j>` identifiers; list and tuple inputs coincide."""
  ex, _ = _make_executor_and_tracer_for_test()
  run = asyncio.get_event_loop().run_until_complete

  first = run(ex.create_value(10, tf.int32))
  self.assertEqual(str(first.identifier), '1')
  second = run(ex.create_value(11, tf.int32))
  self.assertEqual(str(second.identifier), '2')

  # Building the struct from a list or from a tuple returns the same object.
  from_list = run(ex.create_struct([first, second]))
  self.assertEqual(str(from_list.identifier), '<1,2>')
  from_tuple = run(ex.create_struct((first, second)))
  self.assertIs(from_tuple, from_list)

  computed = run(from_tuple.compute())
  self.assertEqual(
      str(structure.map_structure(lambda t: t.numpy(), computed)), '<10,11>')
def max_input_from_bitwidth(bitwidth):
  """Maps each bitwidth to the largest value it can represent.

  Args:
    bitwidth: A bitwidth, or a structure of bitwidths.

  Returns:
    A matching structure. Each leaf `bits` becomes the int64 value
    `2**bits - 1`, computed under a `tf.Assert` that `bits` is at most 62.
  """
  # Secure sum is performed with int64, which has 63 bits, and we need at
  # least one bit to hold the summation of two client values.
  max_secure_sum_bitwidth = 62

  def _max_value_for(bits):
    bound_check = tf.Assert(
        tf.less_equal(bits, max_secure_sum_bitwidth),
        [bits, f'is greater than maximum bitwidth {max_secure_sum_bitwidth}'])
    with tf.control_dependencies([bound_check]):
      base = tf.constant(2, tf.int64)
      return tf.math.pow(base, tf.cast(bits, tf.int64)) - 1

  return structure.map_structure(_max_value_for, bitwidth)
def initialize():
  """Builds the initial reservoir-sampling state.

  Returns:
    An `OrderedDict` with:
      random_seed: `tf.fill(dims=(2,), value=real_seed)`. `real_seed` is
        `seed` converted to int64, or `SEED_SENTINEL` when `seed` is `None`.
      random_values: An empty int32 tensor.
      samples: Empty sample tensors matching `sample_value_type`, each with
        a zero-size leading dimension.

  Raises:
    TypeError: If `sample_value_type` is neither a tensor nor a struct type,
      or if a struct leaf is not a tensor type.
  """
  # Allow fixed seeds, otherwise set a sentinel that signals a seed should be
  # generated upon the first `accumulate` call of the `federated_aggregate`.
  if seed is None:
    real_seed = tf.convert_to_tensor(SEED_SENTINEL, dtype=tf.int64)
  elif tf.is_tensor(seed):
    # An int64 tensor is used as-is; other dtypes are cast. Before this fix,
    # the int64 case left `real_seed` unassigned.
    if seed.dtype == tf.int64:
      real_seed = seed
    else:
      real_seed = tf.cast(seed, dtype=tf.int64)
  else:
    real_seed = tf.convert_to_tensor(seed, dtype=tf.int64)

  def zero_for_tensor_type(t: computation_types.TensorType):
    """Adds an extra first dimension to create a tensor that collects samples.

    For the algebraic zero, the first dimension has size `0`, which gives an
    "empty" tensor. Samples are concatenated onto it as the reservoir fills.

    Args:
      t: A `tff.TensorType` to build a sampling zero value for.

    Returns:
      A tensor whose rank is one larger than `t`'s and whose first dimension
      is zero.

    Raises:
      `TypeError` if `t` is not a `tff.TensorType`.
    """
    if not t.is_tensor():
      raise TypeError(
          f'Cannot create zero for non TesnorType: {type(t)}')
    return tf.zeros([0] + t.shape, dtype=t.dtype)

  if sample_value_type.is_tensor():
    initial_samples = zero_for_tensor_type(sample_value_type)
  elif sample_value_type.is_struct():
    initial_samples = structure.map_structure(zero_for_tensor_type,
                                              sample_value_type)
  else:
    raise TypeError(
        'Cannot build initial reservoir for structure that has '
        'types other than StructWithPythonType or TensorType, '
        f'got {sample_value_type!r}.')
  return collections.OrderedDict(
      random_seed=tf.fill(dims=(2,), value=real_seed),
      random_values=tf.zeros([0], tf.int32),
      samples=initial_samples)