async def create_struct(self, elements):
  """Builds an eager tuple value out of `elements`.

  Args:
    elements: As documented in `executor_base.Executor`.

  Returns:
    An `EagerValue` whose internal representation is a `structure.Struct` of
    the elements' internal representations, typed as the matching
    `StructType`.
  """
  named_elements = structure.to_elements(structure.from_container(elements))
  for _, element in named_elements:
    py_typecheck.check_type(element, EagerValue)
  internal_elements = [
      (name, element.internal_representation)
      for name, element in named_elements
  ]
  # Unnamed members are given to StructType as a bare type, named ones as a
  # (name, type) pair.
  member_types = [
      element.type_signature if name is None else
      (name, element.type_signature) for name, element in named_elements
  ]
  return EagerValue(
      structure.Struct(internal_elements), self._tf_function_cache,
      computation_types.StructType(member_types))
async def create_struct(self, elements):
  """Asks the remote service to assemble a struct out of `elements`.

  Args:
    elements: Anything accepted by `structure.from_container` whose members
      are `RemoteValue`s.

  Returns:
    A `RemoteValue` referencing the struct created remotely, typed as the
    `StructType` of the members.
  """
  struct_value = structure.from_container(elements)
  request_elements = []
  member_types = []
  for name, remote_value in structure.iter_elements(struct_value):
    py_typecheck.check_type(remote_value, RemoteValue)
    request_elements.append(
        executor_pb2.CreateStructRequest.Element(
            name=name if name else None, value_ref=remote_value.value_ref))
    if name:
      member_types.append((name, remote_value.type_signature))
    else:
      member_types.append(remote_value.type_signature)
  result_type = computation_types.StructType(member_types)
  request = executor_pb2.CreateStructRequest(element=request_elements)
  if self._bidi_stream is not None:
    # Streaming mode: wrap the request in the generic ExecuteRequest envelope.
    wrapped_request = executor_pb2.ExecuteRequest(create_struct=request)
    stream_response = await self._bidi_stream.send_request(wrapped_request)
    response = stream_response.create_struct
  else:
    response = _request(self._stub.CreateStruct, request)
  py_typecheck.check_type(response, executor_pb2.CreateStructResponse)
  return RemoteValue(response.value_ref, result_type, self)
async def create_value(self, value, type_spec=None):
  """Embeds `value` (optionally typed by `type_spec`) in this executor.

  Computations are unwrapped to their protos and evaluated. Struct-typed
  values are embedded element by element. Everything else is delegated to
  the target executor.
  """
  type_spec = computation_types.to_type(type_spec)
  if isinstance(value, computation_impl.ComputationImpl):
    proto = computation_impl.ComputationImpl.get_proto(value)
    reconciled_type = type_utils.reconcile_value_with_type_spec(
        value, type_spec)
    return await self.create_value(proto, reconciled_type)
  if isinstance(value, pb.Computation):
    return await self._evaluate(value)
  if type_spec is None or not type_spec.is_struct():
    target_value = await self._target_executor.create_value(value, type_spec)
    return ReferenceResolvingExecutorValue(target_value)
  named_values = structure.to_elements(structure.from_container(value))
  embedded = await asyncio.gather(*[
      self.create_value(element, element_type)
      for (_, element), element_type in zip(named_values, type_spec)
  ])
  return ReferenceResolvingExecutorValue(
      structure.Struct(
          (name, embedded_value)
          for (name, _), embedded_value in zip(named_values, embedded)))
def infer_cardinalities(value, type_spec):
  """Derives placement cardinalities from a Python `value`.

  Any Python object may stand in for a federated value here. Checking its
  representation is left to the ingestion functions further down the stack.

  Args:
    value: Python object to read TFF placement cardinalities from.
    type_spec: The TFF type of `value`. Cardinalities are only taken from
      federated (sub)types of it.

  Returns:
    A dict mapping placements to cardinalities.

  Raises:
    ValueError: If `value` implies conflicting cardinalities.
    TypeError: If the arguments have the wrong types, or if a non-`all_equal`
      federated part of `value` has no length (is not `Sized`).
  """
  py_typecheck.check_not_none(value)
  py_typecheck.check_type(type_spec, computation_types.Type)
  if isinstance(value, cardinality_carrying_base.CardinalityCarrying):
    return value.cardinality
  if type_spec.is_federated():
    if type_spec.all_equal:
      return {}
    py_typecheck.check_type(value, collections.abc.Sized)
    return {type_spec.placement: len(value)}
  if not type_spec.is_struct():
    return {}
  value_struct = structure.from_container(value, recursive=False)
  merged = {}
  for index, (_, element_type) in enumerate(
      structure.iter_elements(type_spec)):
    element_cardinalities = infer_cardinalities(value_struct[index],
                                                element_type)
    merged = merge_cardinalities(merged, element_cardinalities)
  return merged
def _to_struct_internal_rep(
    *,
    value: Any,
    tf_function_cache: MutableMapping[str, Any],
    type_spec: computation_types.StructType,
    device: tf.config.LogicalDevice) -> structure.Struct:
  """Converts a python container to internal representation for TF executor.

  Each element is converted with `to_representation_for_type` against the
  corresponding element of `type_spec`. Element counts and names must match
  exactly, otherwise `TypeError` is raised.
  """
  type_elements = structure.to_elements(type_spec)
  value_elements = structure.to_elements(structure.from_container(value))
  if len(type_elements) != len(value_elements):
    raise TypeError('Expected a {}-element tuple, found {} elements.'.format(
        len(type_elements), len(value_elements)))
  converted = []
  for (type_name, element_type), (value_name, element) in zip(
      type_elements, value_elements):
    if type_name != value_name:
      raise TypeError(
          'Mismatching element names in type vs. value: {} vs. {}.'.format(
              type_name, value_name))
    converted.append(
        (type_name,
         to_representation_for_type(element, tf_function_cache, element_type,
                                    device)))
  return structure.Struct(converted)
async def create_struct(self, elements):
  """Creates, or reuses from the cache, a struct built from `elements`.

  The cache key is a string of the form `<name=id,...>` built from the
  members' `identifier`s. A miss schedules `create_struct` on the target
  executor and stores the resulting `CachedValue` before awaiting it.

  Args:
    elements: A `structure.Struct` or a Python container accepted by
      `structure.from_container`, whose members are `CachedValue`s.

  Returns:
    The (possibly previously cached) `CachedValue` for the struct.

  Raises:
    TypeError: If a member is not a `CachedValue` or a name is not a `str`.
    Exception: Whatever the target executor raises; the whole cache is
      cleared first.
  """
  if not isinstance(elements, structure.Struct):
    elements = structure.from_container(elements)
  element_strings = []
  element_kv_pairs = structure.to_elements(elements)
  to_gather = []
  type_elements = []
  for k, v in element_kv_pairs:
    py_typecheck.check_type(v, CachedValue)
    to_gather.append(v.target_future)
    if k is not None:
      py_typecheck.check_type(k, str)
      element_strings.append('{}={}'.format(k, v.identifier))
      type_elements.append((k, v.type_signature))
    else:
      element_strings.append(str(v.identifier))
      type_elements.append(v.type_signature)
  type_spec = computation_types.StructType(type_elements)
  # Wait for every member's target value. These are what gets handed to the
  # target executor on a cache miss.
  gathered = await asyncio.gather(*to_gather)
  identifier = CachedValueIdentifier('<{}>'.format(','.join(element_strings)))
  try:
    cached_value = self._cache[identifier]
  except KeyError:
    # Cache miss: schedule the target-side struct creation and register the
    # entry before awaiting it.
    target_future = asyncio.ensure_future(
        self._target_executor.create_struct(
            structure.Struct(
                (k, v) for (k, _), v in zip(element_kv_pairs, gathered))))
    cached_value = CachedValue(identifier, None, type_spec, target_future)
    self._cache[identifier] = cached_value
  try:
    target_value = await cached_value.target_future
  except Exception:
    # TODO(b/145514490): This is a bit heavy handed, there maybe caches where
    # only the current cache item needs to be invalidated; however this
    # currently only occurs when an inner RemoteExecutor has the backend go
    # down.
    self._cache = {}
    raise
  # The struct type computed here must accept what the target produced.
  type_spec.check_assignable_from(target_value.type_signature)
  return cached_value
def _serialize_struct_type(
    value_proto: serialization_bindings.Value,
    struct_typed_value: Any,
    type_spec: computation_types.StructType,
) -> 'tuple[serialization_bindings.Value, computation_types.StructType]':
  """Serializes a value of tuple type into `value_proto.struct`.

  Args:
    value_proto: The proto to fill in. Its `struct` field is populated in
      place, one element per member of `type_spec`.
    struct_typed_value: A Python container accepted by
      `structure.from_container`, with one member per element of `type_spec`.
    type_spec: The struct type to serialize against.

  Returns:
    The pair `(value_proto, type_spec)`.

  Raises:
    TypeError: If the value and the type have different numbers of elements.
  """
  value_structure = structure.from_container(struct_typed_value)
  if len(value_structure) != len(type_spec):
    raise TypeError(
        'Cannot serialize a struct value of '
        f'{len(value_structure)} elements to a struct type '
        f'requiring {len(type_spec)} elements. Trying to serialize'
        f'\n{struct_typed_value!r}\nto\n{type_spec}.')
  struct_proto = value_proto.struct
  # Element names come from the type, not from the value.
  for (e_name, e_type), (_, e_val) in zip(
      structure.iter_elements(type_spec),
      structure.iter_elements(value_structure)):
    element_proto = struct_proto.element.add()
    serialize_value(e_val, e_type, value_proto=element_proto.value)
    if e_name:
      element_proto.name = e_name
  return value_proto, type_spec
async def _encrypt_values_on_singleton(self, val, sender, receiver):
  """Encrypts a single-member federated value once per receiver public key.

  Everything runs on the single child executor of `sender`. An encryptor
  computation is materialized (via `self._encryptor_cache`). It is then
  called once for each public key in `pk_receiver`, each time on the
  argument `[plaintext, that public key, sender secret key]`.

  Args:
    val: Federated value placed at `sender`. Its internal representation must
      have exactly one element.
    sender: Placement of `val` and of both keys used here. It must have
      exactly one child executor.
    receiver: Placement whose public keys are used. The number of public keys
      must equal the number of `receiver` child executors.

  Returns:
    A `FederatedResolvingStrategyValue` holding one encrypted value per
    receiver public key. Its type is a `tff.StructType` of
    `tff.FederatedType(encryptor_result, sender, all_equal=False)` elements.
  """
  ###
  # we can safely assume sender has cardinality=1 when receiver is CLIENTS
  ###
  # Case 1: receiver=CLIENTS
  #     plaintext: Fed(Tensor, sender, all_equal=True)
  #     pk_receiver: Fed(Tuple(Tensor), sender, all_equal=True)
  #     sk_sender: Fed(Tensor, sender, all_equal=True)
  #   Returns:
  #     encrypted_values: Tuple(Fed(Tensor, sender, all_equal=True))
  # NOTE(review): the return statement below builds all_equal=False member
  # types, which disagrees with the "Returns" line above; confirm which is
  # intended.
  ###

  ### Check proper key placement
  sk_sender = self.key_references.get_secret_key(sender)
  pk_receiver = self.key_references.get_public_key(receiver)
  type_analysis.check_federated_type(sk_sender.type_signature, placement=sender)
  # NOTE(review): `assert` is stripped under `python -O`, so these placement
  # checks silently disappear in optimized runs.
  assert sk_sender.type_signature.placement is sender
  # The receiver's public keys are expected to already be placed at `sender`.
  assert pk_receiver.type_signature.placement is sender

  ### Check placement cardinalities
  rcv_children = self.strategy._get_child_executors(receiver)
  snd_children = self.strategy._get_child_executors(sender)
  py_typecheck.check_len(snd_children, 1)
  snd_child = snd_children[0]

  ### Check value cardinalities
  type_analysis.check_federated_type(val.type_signature, placement=sender)
  py_typecheck.check_len(val.internal_representation, 1)
  py_typecheck.check_type(pk_receiver.type_signature.member, tff.StructType)
  # One public key per receiver child executor.
  py_typecheck.check_len(pk_receiver.internal_representation, len(rcv_children))
  py_typecheck.check_len(sk_sender.internal_representation, 1)

  ### Materialize encryptor function definition & type spec
  input_type = val.type_signature.member
  # Records the plaintext member type on the instance; how it is consumed is
  # not visible from this method.
  self._input_type_cache = input_type
  pk_rcv_type = pk_receiver.type_signature.member
  sk_snd_type = sk_sender.type_signature.member
  # The encryptor is specialized to the type of the first public key and is
  # reused for every key.
  pk_element_type = pk_rcv_type[0]
  encryptor_arg_spec = (input_type, pk_element_type, sk_snd_type)
  encryptor_proto, encryptor_type = utils.materialize_computation_from_cache(
      sodium_comp.make_encryptor, self._encryptor_cache, encryptor_arg_spec)

  ### Prepare encryption arguments
  v = val.internal_representation[0]
  sk = sk_sender.internal_representation[0]

  ### Encrypt values and return them
  encryptor_fn = await snd_child.create_value(encryptor_proto, encryptor_type)
  encryptor_args = await asyncio.gather(*[
      snd_child.create_struct([v, this_pk, sk])
      for this_pk in pk_receiver.internal_representation
  ])
  encrypted_values = await asyncio.gather(*[
      snd_child.create_call(encryptor_fn, arg) for arg in encryptor_args
  ])

  encrypted_value_types = [encryptor_type.result] * len(encrypted_values)
  return federated_resolving_strategy.FederatedResolvingStrategyValue(
      structure.from_container(encrypted_values),
      tff.StructType([
          tff.FederatedType(evt, sender, all_equal=False)
          for evt in encrypted_value_types
      ]))
def serialize_jax_computation(traced_fn, arg_fn, parameter_type, context_stack):
  """Traces JAX code in a Python function and packages it as a TFF computation.

  Args:
    traced_fn: Python function containing JAX code. JAX traces it and the
      result becomes a TFF computation containing XLA code.
    arg_fn: Unpacking function that maps a TFF argument to the
      `(args, kwargs)` pair used to invoke `traced_fn` (for example one built
      by `function_utils.create_argument_unpacking_fn`).
    parameter_type: A `computation_types.Type` for the computation
      parameter, or `None` when the function takes no parameters.
    context_stack: The context stack to install the JAX computation context
      on while tracing.

  Returns:
    A `pb.Computation` wrapping the compiled XLA code.

  Raises:
    TypeError: if the arguments are of the wrong types.
  """
  py_typecheck.check_callable(traced_fn)
  py_typecheck.check_callable(arg_fn)
  py_typecheck.check_type(context_stack, context_stack_base.ContextStack)

  packed_arg = None
  if parameter_type is not None:
    parameter_type = computation_types.to_type(parameter_type)
    packed_arg = _tff_type_to_xla_serializer_arg(parameter_type)

  args, kwargs = arg_fn(packed_arg)

  # The XLA code may order the fake parameters differently from args/kwargs.
  # Flattening with the same function the JAX serializer uses gives us that
  # order, which is recorded in the parameter binding. Results need no such
  # treatment: multiple results always come back as a tuple.
  flat_leaves, _ = jax.tree_util.tree_flatten((args, kwargs))
  tensor_indexes = list(np.argsort([leaf.tensor_index for leaf in flat_leaves]))

  jax_context = jax_computation_context.JaxComputationContext()
  with context_stack.install(jax_context):
    compiled_xla, returned_shape = jax.xla_computation(
        traced_fn, tuple_args=True, return_shape=True)(*args, **kwargs)

  if isinstance(returned_shape, jax.ShapeDtypeStruct):
    result_type = _jax_shape_dtype_struct_to_tff_tensor(returned_shape)
  else:
    shape_struct = structure.from_container(returned_shape, recursive=True)
    result_type = computation_types.to_type(
        structure.map_structure(_jax_shape_dtype_struct_to_tff_tensor,
                                shape_struct))

  return xla_serialization.create_xla_tff_computation(
      compiled_xla, tensor_indexes,
      computation_types.FunctionType(parameter_type, result_type))
async def create_struct(self, elements):
  """Wraps `elements`, converted to a `structure.Struct`, as an executor value."""
  struct_elements = structure.from_container(elements)
  return ReferenceResolvingExecutorValue(struct_elements)
def create_constant(
    value, type_spec: computation_types.Type) -> ComputationProtoAndType:
  """Returns a tensorflow computation returning a constant `value`.

  The returned computation has the type signature `( -> T)`, where `T` is
  `type_spec`.

  `value` must be a value convertible to a tensor or a structure of values,
  such that the dtype and shapes match `type_spec`. `type_spec` must contain
  only named tuples and tensor types, but these can be arbitrarily nested.
  A non-structured `value` is broadcast to every tensor leaf of `type_spec`.

  Args:
    value: A value to embed as a constant in the tensorflow graph.
    type_spec: A `computation_types.Type` to use as the argument to the
      constructed binary operator; must contain only named tuples and tensor
      types.

  Returns:
    The `ComputationProtoAndType` produced by `_tensorflow_comp` for a
    parameterless graph that returns the constant.

  Raises:
    TypeError: If the constraints of `type_spec` are violated.
  """
  if not type_analysis.is_generic_op_compatible_type(type_spec):
    raise TypeError(
        'Type spec {} cannot be constructed as a TensorFlow constant in TFF; '
        'only nested tuples and tensors are permitted.'.format(type_spec))
  inferred_value_type = type_conversions.infer_type(value)
  if (inferred_value_type.is_struct() and
      not type_spec.is_assignable_from(inferred_value_type)):
    raise TypeError(
        'Must pass only a tensor or structure of tensor values to '
        '`create_tensorflow_constant`; encountered a value {v} with inferred '
        'type {t!r}, but needed {s!r}'.format(
            v=value, t=inferred_value_type, s=type_spec))
  if inferred_value_type.is_struct():
    value = structure.from_container(value, recursive=True)
  tensor_dtypes_in_type_spec = []

  def _pack_dtypes(type_signature):
    """Appends dtype of `type_signature` to nonlocal variable."""
    if type_signature.is_tensor():
      tensor_dtypes_in_type_spec.append(type_signature.dtype)
    return type_signature, False

  type_transformations.transform_type_postorder(type_spec, _pack_dtypes)

  # A non-integer scalar cannot be broadcast into a spec containing integer
  # tensors.
  if (any(x.is_integer for x in tensor_dtypes_in_type_spec) and
      (inferred_value_type.is_tensor() and
       not inferred_value_type.dtype.is_integer)):
    raise TypeError(
        'Only integers can be used as scalar values if our desired constant '
        'type spec contains any integer tensors; passed scalar {} of dtype {} '
        'for type spec {}.'.format(value, inferred_value_type.dtype, type_spec))

  result_type = type_spec

  def _create_result_tensor(type_spec, value):
    """Packs `value` into `type_spec` recursively."""
    if type_spec.is_tensor():
      type_spec.shape.assert_is_fully_defined()
      return tf.constant(value, dtype=type_spec.dtype, shape=type_spec.shape)
    if inferred_value_type.is_struct():
      # Copy the leaf values according to the type_spec structure.
      elements = [
          (name, _create_result_tensor(elem_type, elem_value))
          for (name, elem_type), elem_value in zip(
              structure.iter_elements(type_spec), value)
      ]
    else:
      # "Broadcast" the value to each level of the type_spec structure.
      elements = [
          (None, _create_result_tensor(elem_type, value))
          for _, elem_type in structure.iter_elements(type_spec)
      ]
    return structure.Struct(elements)

  with tf.Graph().as_default() as graph:
    result = _create_result_tensor(result_type, value)
    _, result_binding = tensorflow_utils.capture_result_from_graph(
        result, graph)

  type_signature = computation_types.FunctionType(None, result_type)
  tensorflow = pb.TensorFlow(
      graph_def=serialization_utils.pack_graph_def(graph.as_graph_def()),
      parameter=None,
      result=result_binding)
  return _tensorflow_comp(tensorflow, type_signature)
def test_from_container_with_namedtuple(self):
  point_cls = collections.namedtuple('_', 'x y')
  result = structure.from_container(point_cls(1, 2))
  self.assertIsInstance(result, structure.Struct)
  self.assertEqual(str(result), '<x=1,y=2>')
def test_from_container_with_dict(self):
  # Plain dicts are converted with their keys in sorted order.
  unordered = {'z': 10, 'y': 20, 'a': 30}
  result = structure.from_container(unordered)
  self.assertIsInstance(result, structure.Struct)
  self.assertEqual(str(result), '<a=30,y=20,z=10>')
def to_representation_for_type(
    value: Any,
    tf_function_cache: MutableMapping[str, Any],
    type_spec: Optional[computation_types.Type] = None,
    device: Optional[tf.config.LogicalDevice] = None) -> Any:
  """Verifies or converts the `value` to an eager object matching `type_spec`.

  The output of this function is always an eager tensor, eager dataset, a
  representation of a TensorFlow computation, or a nested structure of those
  that matches `type_spec`. When `device` has been specified, everything is
  placed on that device on a best-effort basis.

  TensorFlow computations are represented here as zero- or one-argument Python
  callables that accept their entire argument bundle as a single Python object.
  Embedded computations are memoized in `tf_function_cache`. The key is the
  serialized proto, the string form of its type, and the device name.

  Sequence values may be given either as a TF dataset or as a Python `list` of
  elements. A list is turned into a dataset with
  `tensorflow_utils.make_data_set_from_elements`.

  Args:
    value: The raw representation of a value to compare against `type_spec`
      and potentially to be converted.
    tf_function_cache: A cache obeying `dict` semantics that can be used to
      look up previously embedded TensorFlow functions.
    type_spec: An instance of `tff.Type`, can be `None` for values that derive
      from `typed_object.TypedObject`.
    device: An optional `tf.config.LogicalDevice` to place the value on (for
      tensor-level values).

  Returns:
    Either `value` itself, or a modified version of it.

  Raises:
    TypeError: If the `value` is not compatible with `type_spec`. This
      includes a struct value with the wrong number of elements or mismatched
      element names, and a value embedded within a non-eager executor.
  """
  type_spec = type_utils.reconcile_value_with_type_spec(value, type_spec)
  if isinstance(value, computation_base.Computation):
    return to_representation_for_type(
        computation_impl.ComputationImpl.get_proto(value), tf_function_cache,
        type_spec, device)
  elif isinstance(value, pb.Computation):
    key = (value.SerializeToString(), str(type_spec),
           device.name if device else None)
    cached_fn = tf_function_cache.get(key)
    if cached_fn is not None:
      return cached_fn
    embedded_fn = embed_tensorflow_computation(value, type_spec, device)
    tf_function_cache[key] = embedded_fn
    return embedded_fn
  elif type_spec.is_struct():
    type_elem = structure.to_elements(type_spec)
    value_elem = structure.to_elements(structure.from_container(value))
    result_elem = []
    if len(type_elem) != len(value_elem):
      raise TypeError('Expected a {}-element tuple, found {} elements.'.format(
          len(type_elem), len(value_elem)))
    for (t_name, el_type), (v_name, el_val) in zip(type_elem, value_elem):
      if t_name != v_name:
        raise TypeError(
            'Mismatching element names in type vs. value: {} vs. {}.'.format(
                t_name, v_name))
      el_repr = to_representation_for_type(el_val, tf_function_cache, el_type,
                                           device)
      result_elem.append((t_name, el_repr))
    return structure.Struct(result_elem)
  elif device is not None:
    py_typecheck.check_type(device, tf.config.LogicalDevice)
    # Re-enter without a device so that everything below is created inside
    # the `tf.device` scope.
    with tf.device(device.name):
      return to_representation_for_type(
          value, tf_function_cache, type_spec=type_spec, device=None)
  elif isinstance(value, EagerValue):
    return value.internal_representation
  elif isinstance(value, executor_value_base.ExecutorValue):
    raise TypeError(
        'Cannot accept a value embedded within a non-eager executor.')
  elif type_spec.is_tensor():
    if not tf.is_tensor(value):
      value = tf.convert_to_tensor(value, dtype=type_spec.dtype)
    elif hasattr(value, 'read_value'):
      # a tf.Variable-like result, get a proper tensor.
      value = value.read_value()
    value_type = (
        computation_types.TensorType(value.dtype.base_dtype, value.shape))
    if not type_spec.is_assignable_from(value_type):
      raise TypeError(
          'The apparent type {} of a tensor {} does not match the expected '
          'type {}.'.format(value_type, value, type_spec))
    return value
  elif type_spec.is_sequence():
    if isinstance(value, list):
      value = tensorflow_utils.make_data_set_from_elements(
          None, value, type_spec.element)
    py_typecheck.check_type(value,
                            type_conversions.TF_DATASET_REPRESENTATION_TYPES)
    element_type = computation_types.to_type(value.element_spec)
    value_type = computation_types.SequenceType(element_type)
    type_spec.check_assignable_from(value_type)
    return value
  else:
    raise TypeError('Unexpected type {}.'.format(type_spec))
async def _to_structure(coro: Coroutine[Any, Any, Any]) -> structure.Struct:
  """Awaits `coro` and converts its result into a `structure.Struct`."""
  awaited = await coro
  return structure.from_container(awaited)
def test_from_container_with_int(self):
  not_a_container = 10
  with self.assertRaises(TypeError):
    structure.from_container(not_a_container)
def test_from_container_with_tuple(self):
  result = structure.from_container((10, 20))
  self.assertIsInstance(result, structure.Struct)
  # Unnamed tuple members render without names.
  self.assertEqual(str(result), '<10,20>')
def test_from_container_sparse_tensor(self):
  sparse = tf.SparseTensor(indices=[[1]], values=[2], dense_shape=[5])
  result = structure.from_container(sparse)
  self.assertEqual(str(result), '<indices=[[1]],values=[2],dense_shape=[5]>')
def test_from_container_with_ordered_dict(self):
  # Unlike a plain dict, an OrderedDict keeps its insertion order.
  ordered = collections.OrderedDict([('z', 10), ('y', 20), ('a', 30)])
  result = structure.from_container(ordered)
  self.assertIsInstance(result, structure.Struct)
  self.assertEqual(str(result), '<z=10,y=20,a=30>')
def test_from_container_asdict_roundtrip(self, dict_in):
  as_struct = structure.from_container(dict_in, recursive=True)
  self.assertEqual(dict_in, as_struct._asdict(recursive=True))
def test_from_container_with_struct(self):
  original = structure.Struct([('a', 10), ('b', 20)])
  x = structure.from_container(original)
  # A `Struct` input needs no conversion; a non-recursive call is expected to
  # hand back the very same object. (The previous `assertIs(x, x)` was
  # vacuous.)
  self.assertIs(x, original)
def test_from_container_raises_on_non_container_argument(self):
  self.assertRaises(TypeError, structure.from_container, 3)
def test_to_odict_or_tuple_from_container_roundtrip(self, original):
  roundtripped = structure.to_odict_or_tuple(
      structure.from_container(original, recursive=True))
  self.assertEqual(original, roundtripped)
def test_to_odict_or_tuple_empty_dict_becomes_empty_tuple(self):
  empty_struct = structure.from_container(collections.OrderedDict())
  self.assertEqual(structure.to_odict_or_tuple(empty_struct), ())
def _ensure_structure(obj):
  """Recursively converts the Python container `obj` into a `structure.Struct`."""
  return structure.from_container(obj, recursive=True)
def test_to_odict_or_tuple_mixed_nonrecursive(self):
  nested = ODict(a=1, b=2, c=(3, ODict(d=4, e=5)))
  shallow_struct = structure.from_container(nested, recursive=False)
  self.assertEqual(
      nested, structure.to_odict_or_tuple(shallow_struct, recursive=False))
def result_computation(arg):
  """Evaluates `to_evaluate` with `arg` bound to `parameter_name`."""
  # Struct-typed parameters are normalized into a `structure.Struct` first.
  bound_arg = (
      structure.from_container(arg, recursive=True)
      if parameter_type.is_struct() else arg)
  return _evaluate_to_tensorflow(to_evaluate, {parameter_name: bound_arg})
async def create_value(self, value, type_spec=None):
  """Creates a value in this executor.

  The following kinds of `value` are supported as the input:

  * An instance of TFF computation proto containing one of the supported
    sequence intrinsics as its sole body.

  * An instance of eager TF dataset.

  * Anything that is supported by the target executor (as a pass-through).

  * A nested structure of any of the above.

  Args:
    value: The input for which to create a value.
    type_spec: An optional TFF type (required if `value` is not an instance of
      `typed_object.TypedObject`, otherwise it can be `None`).

  Returns:
    An instance of `SequenceExecutorValue` that represents the embedded value.

  Raises:
    ValueError: If `value` is a computation proto naming an unrecognized
      intrinsic.
    TypeError: If a struct-typed `value` does not have the same number of
      top-level elements as `type_spec`.
  """
  if type_spec is None:
    py_typecheck.check_type(value, typed_object.TypedObject)
    type_spec = value.type_signature
  else:
    type_spec = computation_types.to_type(type_spec)
  if isinstance(type_spec, computation_types.SequenceType):
    return SequenceExecutorValue(
        _SequenceFromPayload(value, type_spec), type_spec)
  if isinstance(value, pb.Computation):
    value_type = type_serialization.deserialize_type(value.type)
    value_type.check_equivalent_to(type_spec)
    which_computation = value.WhichOneof('computation')
    # NOTE: If not a supported type of intrinsic, we let it fall through and
    # be handled by embedding in the target executor (below).
    if which_computation == 'intrinsic':
      intrinsic_def = intrinsic_defs.uri_to_intrinsic_def(value.intrinsic.uri)
      if intrinsic_def is None:
        raise ValueError(
            'Encountered an unrecognized intrinsic "{}".'.format(
                value.intrinsic.uri))
      op_type = SequenceExecutor._SUPPORTED_INTRINSIC_TO_SEQUENCE_OP.get(
          intrinsic_def.uri)
      if op_type is not None:
        type_analysis.check_concrete_instance_of(
            type_spec, intrinsic_def.type_signature)
        op = op_type(type_spec)
        return SequenceExecutorValue(op, type_spec)
  if isinstance(type_spec, computation_types.StructType):
    if not isinstance(value, structure.Struct):
      value = structure.from_container(value)
    value_elements = structure.to_elements(value)
    type_elements = structure.to_elements(type_spec)
    if len(value_elements) != len(type_elements):
      raise TypeError('Expected a {}-element tuple, found {} elements.'.format(
          len(type_elements), len(value_elements)))
    # Embed each top-level element against its own element type. Nested
    # structs (including nested Python containers, which the non-recursive
    # `from_container` above leaves unconverted) are handled by the recursive
    # call. This keeps values and types aligned, and `create_struct` only
    # ever sees a single level of executor values.
    embedded_elements = await asyncio.gather(*[
        self.create_value(element, element_type)
        for (_, element), (_, element_type) in zip(value_elements,
                                                   type_elements)
    ])
    embedded_struct = structure.Struct(
        (name, embedded)
        for (name, _), embedded in zip(value_elements, embedded_elements))
    return await self.create_struct(embedded_struct)
  target_value = await self._target_executor.create_value(value, type_spec)
  return SequenceExecutorValue(target_value, type_spec)
class FederatingExecutorCreateCallTest(executor_test_utils.AsyncTestCase,
                                       parameterized.TestCase,
                                       test_case.TestCase):

  @parameterized.named_parameters(
      (
          'scalar_sum_less_than_mask',
          [10, 11, 12],
          10,  # Larger than both the sum and every single client value.
          33,  # 33 mod (2**(10+2))
      ),
      (
          'scalar_sum_less_than_extended_mask',
          [10, 11, 12],
          4,  # Larger than each individual client value, smaller than the sum.
          33,  # 33 mod (2**(4+2))
      ),
      (
          'scalar_sum_greater_than_extend_mask',
          [10, 11, 12],
          2,  # Smaller than every individual client value.
          1,  # 33 mod (2**(2+2))
      ),
      (
          'structured_scalar_sum_less_than_mask',
          [[10, 0], [11, 1], [12, 2]],
          [10, 10],
          structure.from_container([33, 3]),
      ),
      (
          'structued_scalar_sum_greater_than_mask',
          [[10, 0], [11, 1], [12, 2]],
          [3, 3],
          structure.from_container([1, 3]),
      ),
      (
          'named_structured_scalar_sum_less_than_mask',
          [
              collections.OrderedDict(a=10, b=0),
              collections.OrderedDict(a=11, b=1),
              collections.OrderedDict(a=12, b=2),
          ],
          collections.OrderedDict(a=10, b=10),
          structure.from_container(collections.OrderedDict(a=33, b=3)),
      ),
      (
          'named_structued_scalar_sum_greater_than_mask',
          [
              collections.OrderedDict(a=10, b=0),
              collections.OrderedDict(a=11, b=1),
              collections.OrderedDict(a=12, b=2),
          ],
          collections.OrderedDict(a=3, b=3),
          structure.from_container(collections.OrderedDict(a=1, b=3)),
      ),
      (
          'nested_structured_tensor_sum',
          [
              collections.OrderedDict(
                  a=100, b=[tf.constant([0, 1, 2]),
                            tf.constant([10, 70])]),
              collections.OrderedDict(
                  a=1000, b=[tf.constant([1, 2, 3]),
                             tf.constant([20, 80])]),
              collections.OrderedDict(
                  a=10000, b=[tf.constant([2, 3, 4]),
                              tf.constant([30, 90])]),
          ],
          collections.OrderedDict(a=16, b=[4, 5]),
          structure.from_container(
              collections.OrderedDict(
                  a=11100,
                  b=structure.from_container([
                      tf.constant([3, 6, 9]),
                      tf.constant([60, 240 & (2**7 - 1)]),
                  ]))),
      ),
  )
  def test_returns_value_with_intrinsic_def_federated_secure_sum(
      self, client_values, bitwidth, expected_result):
    executor = create_test_executor()
    value_type = computation_types.at_clients(
        type_conversions.infer_type(client_values[0]))
    bitwidth_type = type_conversions.infer_type(bitwidth)
    comp, comp_type = create_intrinsic_def_federated_secure_sum(
        value_type.member, bitwidth_type)

    run = self.run_sync
    embedded_comp = run(executor.create_value(comp, comp_type))
    embedded_values = run(executor.create_value(client_values, value_type))
    embedded_bitwidth = run(executor.create_value(bitwidth, bitwidth_type))
    embedded_args = run(
        executor.create_struct([embedded_values, embedded_bitwidth]))
    result = run(executor.create_call(embedded_comp, embedded_args))

    self.assertIsInstance(result, executor_value_base.ExecutorValue)
    self.assert_types_identical(result.type_signature, comp_type.result)
    actual_result = run(result.compute())
    if not isinstance(expected_result, structure.Struct):
      self.assertEqual(actual_result, expected_result)
      return
    # Structured results are compared leaf by leaf (tensors included).
    structure.map_structure(self.assertAllEqual, actual_result,
                            expected_result)
def test_from_container_with_namedtuple_of_odict_recursive(self):
  pair_cls = collections.namedtuple('_', 'x y')
  nested = pair_cls(
      collections.OrderedDict([('a', 10), ('b', 20)]),
      collections.OrderedDict([('c', 30), ('d', 40)]))
  result = structure.from_container(nested, recursive=True)
  self.assertEqual(str(result), '<x=<a=10,b=20>,y=<c=30,d=40>>')