def _to_struct_internal_rep(
    *,
    value: Any,
    tf_function_cache: MutableMapping[str, Any],
    type_spec: computation_types.StructType,
    device: tf.config.LogicalDevice,
) -> structure.Struct:
  """Converts a python container to internal representation for TF executor.

  Args:
    value: A Python container accepted by `structure.from_container`, whose
      elements correspond positionally to the elements of `type_spec`.
    tf_function_cache: Forwarded unchanged to `to_representation_for_type`.
    type_spec: The `computation_types.StructType` describing `value`.
    device: Forwarded unchanged to `to_representation_for_type`.

  Returns:
    A `structure.Struct` whose names are taken from `type_spec` and whose
    values are the per-element results of `to_representation_for_type`.

  Raises:
    TypeError: If `value` and `type_spec` have different element counts, or
      if a named element of `value` does not match the name in `type_spec`.
  """
  type_pairs = structure.iter_elements(type_spec)
  value_struct = structure.from_container(value)
  value_pairs = structure.iter_elements(value_struct)

  expected_count, actual_count = len(type_spec), len(value_struct)
  if expected_count != actual_count:
    raise TypeError('Mismatched number of elements between type spec and value '
                    'in `to_representation_for_type`. Type spec has {} '
                    'elements, value has {}.'.format(expected_count,
                                                     actual_count))

  converted = []
  for (type_name, elem_type), (val_name, elem_val) in zip(type_pairs,
                                                          value_pairs):
    # Unnamed value elements adopt the type's name; named ones must agree.
    if val_name is not None and val_name != type_name:
      raise TypeError(
          'Mismatching element names in type vs. value: {} vs. {}.'.format(
              type_name, val_name))
    converted.append(
        (type_name,
         to_representation_for_type(elem_val, tf_function_cache, elem_type,
                                    device)))
  return structure.Struct(converted)
def is_valid_bitwidth_type_for_value_type(
    bitwidth_type: computation_types.Type,
    value_type: computation_types.Type) -> bool:
  """Whether or not `bitwidth_type` is a valid bitwidth type for `value_type`.

  Two shapes of input are accepted:
    * both are tensors, and `bitwidth_type` is a single-element integer
      tensor (one bitwidth per value tensor);
    * both are structs with the same number of elements and the same element
      names, and every pair of elements is itself valid (recursively).
  Every other combination is rejected.

  NOTE: this function is primarily a helper for `intrinsic_factory.py`'s
  `federated_secure_sum` function.
  """
  py_typecheck.check_type(bitwidth_type, computation_types.Type)
  py_typecheck.check_type(value_type, computation_types.Type)

  if value_type.is_tensor() and bitwidth_type.is_tensor():
    # Rather than requiring identical types, require a single integer, since
    # we want exactly one bitwidth integer per tensor.
    return bitwidth_type.dtype.is_integer and (
        bitwidth_type.shape.num_elements() == 1)

  if not (value_type.is_struct() and bitwidth_type.is_struct()):
    return False

  bitwidth_elements = list(structure.iter_elements(bitwidth_type))
  value_elements = list(structure.iter_elements(value_type))
  if len(bitwidth_elements) != len(value_elements):
    return False

  return all(
      bw_name == val_name and
      is_valid_bitwidth_type_for_value_type(bw_type, val_type)
      for (bw_name, bw_type), (val_name, val_type) in zip(
          bitwidth_elements, value_elements))
def is_valid_bitwidth_type_for_value_type(
    bitwidth_type: computation_types.Type,
    value_type: computation_types.Type) -> bool:
  """Whether or not `bitwidth_type` is a valid bitwidth type for `value_type`.

  A tensor `bitwidth_type` is valid when it is a single-element integer
  tensor, whatever `value_type` is. A struct `bitwidth_type` is valid only
  against a struct `value_type` with the same element count and names, and
  only if each element pair is valid (recursively). All else is invalid.

  NOTE: this function is primarily a helper for `intrinsic_factory.py`'s
  `federated_secure_sum` function.
  """
  py_typecheck.check_type(bitwidth_type, computation_types.Type)
  py_typecheck.check_type(value_type, computation_types.Type)

  if bitwidth_type.is_tensor():
    # Applies to both tensor and struct `value_type`: a single integer
    # bitwidth can be shared by every value in a structure.
    return bitwidth_type.dtype.is_integer and (
        bitwidth_type.shape.num_elements() == 1)

  if not (value_type.is_struct() and bitwidth_type.is_struct()):
    return False

  bitwidth_elements = list(structure.iter_elements(bitwidth_type))
  value_elements = list(structure.iter_elements(value_type))
  if len(bitwidth_elements) != len(value_elements):
    return False

  return all(
      bw_name == val_name and
      is_valid_bitwidth_type_for_value_type(bw_type, val_type)
      for (bw_name, bw_type), (val_name, val_type) in zip(
          bitwidth_elements, value_elements))
def is_single_integer_or_matches_structure(
    type_sig: computation_types.Type,
    shape_type: computation_types.Type) -> bool:
  """If `type_sig` is an integer or integer structure matching `shape_type`.

  A tensor `type_sig` qualifies when it is an integer tensor with exactly one
  element, whatever `shape_type` is. A struct `type_sig` qualifies only
  against a struct `shape_type` with the same number of elements and the
  same element names, and only if every element pair qualifies
  (recursively). Any other combination does not qualify.

  Args:
    type_sig: The candidate `computation_types.Type`.
    shape_type: The `computation_types.Type` whose structure `type_sig` must
      match when `type_sig` is a struct.

  Returns:
    `True` if `type_sig` qualifies as described above, else `False`.
  """
  py_typecheck.check_type(type_sig, computation_types.Type)
  py_typecheck.check_type(shape_type, computation_types.Type)

  if type_sig.is_tensor():
    # This condition applies to both `shape_type` being a tensor or structure,
    # as the same integer bitwidth can be used for all values in the structure.
    return type_sig.dtype.is_integer and (type_sig.shape.num_elements() == 1)
  elif shape_type.is_struct() and type_sig.is_struct():
    type_sig_name_and_types = list(structure.iter_elements(type_sig))
    shape_name_and_types = list(structure.iter_elements(shape_type))
    if len(type_sig_name_and_types) != len(shape_name_and_types):
      return False
    # Use distinct loop names so the `type_sig` parameter is never shadowed.
    for (inner_name, inner_type_sig), (inner_shape_name,
                                       inner_shape_type) in zip(
                                           type_sig_name_and_types,
                                           shape_name_and_types):
      if inner_name != inner_shape_name:
        return False
      if not is_single_integer_or_matches_structure(inner_type_sig,
                                                    inner_shape_type):
        return False
    return True
  else:
    return False
def _serialize_struct_type(
    struct_typed_value: Any,
    type_spec: computation_types.StructType,
) -> 'Tuple[executor_pb2.Value, computation_types.StructType]':
  """Serializes a value of struct (tuple) type.

  Args:
    struct_typed_value: A value accepted by `structure.from_container`, with
      one element per element of `type_spec`.
    type_spec: The `computation_types.StructType` of the value.

  Returns:
    A pair `(value_proto, type_spec)`. `value_proto` is an
    `executor_pb2.Value` holding a struct whose elements are serialized with
    `serialize_value`. `type_spec` is returned unchanged.

  Raises:
    TypeError: If the element count of the value differs from `type_spec`.
  """
  value_structure = structure.from_container(struct_typed_value)
  if len(value_structure) != len(type_spec):
    raise TypeError(
        'Cannot serialize a struct value of '
        f'{len(value_structure)} elements to a struct type '
        f'requiring {len(type_spec)} elements. Trying to serialize'
        f'\n{struct_typed_value!r}\nto\n{type_spec}.')
  type_elem_iter = structure.iter_elements(type_spec)
  val_elem_iter = structure.iter_elements(value_structure)
  elements = []
  for (e_name, e_type), (_, e_val) in zip(type_elem_iter, val_elem_iter):
    e_value, _ = serialize_value(e_val, e_type)
    # Element names are taken from the type; unnamed elements leave the proto
    # `name` field unset.
    if e_name:
      element = executor_pb2.Value.Struct.Element(name=e_name, value=e_value)
    else:
      element = executor_pb2.Value.Struct.Element(value=e_value)
    elements.append(element)
  value_proto = executor_pb2.Value(struct=executor_pb2.Value.Struct(
      element=elements))
  return value_proto, type_spec
def from_tff_result(cls, struct):
  """Builds a `cls` instance from a TFF `structure.Struct`.

  `struct` must expose `trainable` and `non_trainable` sub-structures. The
  element values of each (names discarded) are passed to `cls` as two lists,
  in that order.
  """
  py_typecheck.check_type(struct, structure.Struct)

  def _element_values(sub_struct):
    return [element for _, element in structure.iter_elements(sub_struct)]

  return cls(
      _element_values(struct.trainable),
      _element_values(struct.non_trainable))
async def create_struct(self, elements):
  """Creates a struct in the target executor from wrapped `elements`.

  Each element's `internal_representation` is unwrapped, keeping its name.
  The resulting `structure.Struct` is passed to the target executor's
  `create_struct`, and that call is routed through `self._delegate`.
  """
  as_struct = structure.from_container(elements)
  unwrapped = structure.Struct(
      (name, wrapped.internal_representation)
      for name, wrapped in structure.iter_elements(as_struct))
  target_call = self._target_executor.create_struct(unwrapped)
  return await self._delegate(target_call)
def _create_structure_of_coro_references(
    coro: Coroutine[Any, Any, Any],
    type_signature: computation_types.Type) -> Any:
  """Returns a structure of `tff.program.CoroValueReference`s.

  * Struct types: `coro` is wrapped in one shared awaitable. Each element
    becomes a (recursively built) reference that awaits the shared result
    and indexes into it.
  * Server-placed federated types: handled as their member type.
  * Sequence and tensor types: a single `CoroValueReference`.

  Raises:
    NotImplementedError: For any other type.
  """
  py_typecheck.check_type(type_signature, computation_types.Type)

  if type_signature.is_struct():

    async def _await_as_struct(
        pending: Coroutine[Any, Any, Any]) -> structure.Struct:
      return structure.from_container(await pending)

    # All element references share a single evaluation of `coro`.
    shared = async_utils.SharedAwaitable(_await_as_struct(coro))

    async def _select(awaitable: Awaitable[structure.Struct],
                      position: int) -> Any:
      return (await awaitable)[position]

    return structure.Struct([
        (name, _create_structure_of_coro_references(
            _select(shared, position), element_type))
        for position, (name, element_type) in enumerate(
            structure.iter_elements(type_signature))
    ])

  if (type_signature.is_federated() and
      type_signature.placement == placements.SERVER):
    return _create_structure_of_coro_references(coro, type_signature.member)

  if type_signature.is_sequence() or type_signature.is_tensor():
    return CoroValueReference(coro, type_signature)

  raise NotImplementedError(f'Unexpected type found: {type_signature}.')
async def _materialize_structure_of_value_references(
    value: Any, type_signature: computation_types.Type) -> Any:
  """Returns a structure of materialized values.

  * Struct types: elements are materialized concurrently (recursively) and
    reassembled using the element names from `type_signature`.
  * Server-placed federated types: handled as their member type.
  * Sequence and tensor types: a `MaterializableValueReference` is replaced
    by the result of `get_value()`; other values pass through.
  * Anything else: returned unchanged.
  """
  py_typecheck.check_type(type_signature, computation_types.Type)

  async def _resolve(candidate: Any) -> Any:
    if not isinstance(candidate, value_reference.MaterializableValueReference):
      return candidate
    return await candidate.get_value()

  if type_signature.is_struct():
    value = structure.from_container(value)
    named_types = list(structure.iter_elements(type_signature))
    materialized = await asyncio.gather(*(
        _materialize_structure_of_value_references(element, element_type)
        for element, (_, element_type) in zip(value, named_types)))
    return structure.Struct([
        (name, element)
        for element, (name, _) in zip(materialized, named_types)
    ])

  if (type_signature.is_federated() and
      type_signature.placement == placements.SERVER):
    return await _materialize_structure_of_value_references(
        value, type_signature.member)

  if type_signature.is_sequence() or type_signature.is_tensor():
    return await _resolve(value)

  return value
def _repackage_partitioned_values(after_merge_results,
                                  result_type_spec: computation_types.Type):
  """Inverts `_split_value_into_subrounds`.

  Args:
    after_merge_results: A `list` with one result per subround.
    result_type_spec: The type of the combined result.

  Returns:
    * Struct types: a `structure.Struct` assembled element-wise by recursing
      on the per-subround elements.
    * Non-all-equal clients-placed federated types: the concatenation of the
      per-subround lists/tuples.
    * Anything else, including all-equal clients-placed values: the first
      subround's result.
  """
  py_typecheck.check_type(after_merge_results, list)

  if result_type_spec.is_struct():
    per_round_structs = [
        structure.from_container(result) for result in after_merge_results
    ]
    return structure.Struct([
        (name, _repackage_partitioned_values(
            [round_struct[position] for round_struct in per_round_structs],
            elem_type))
        for position, (name, elem_type) in enumerate(
            structure.iter_elements(result_type_spec))
    ])

  clients_placed = (
      result_type_spec.is_federated() and
      result_type_spec.placement.is_clients())
  if not clients_placed or result_type_spec.all_equal:
    return after_merge_results[0]

  for chunk in after_merge_results:
    py_typecheck.check_type(chunk, (list, tuple))
  # Merges all clients-placed values back together.
  return functools.reduce(lambda merged, chunk: merged + chunk,
                          after_merge_results)
async def create_struct(
    self, elements: List[executor_value_base.ExecutorValue]
) -> executor_value_base.ExecutorValue:
  """Creates an embedded tuple of the given `elements`.

  Args:
    elements: A collection of embedded values.

  Returns:
    An instance of `executor_value_base.ExecutorValue` representing the
    embedded tuple.

  Raises:
    TypeError: If the `elements` are not embedded in the executor.
  """
  internal_elements = []
  member_types = []
  for name, embedded in structure.iter_elements(
      structure.from_container(elements)):
    py_typecheck.check_type(embedded, executor_value_base.ExecutorValue)
    internal_elements.append((name, embedded.internal_representation))
    # Unnamed members contribute a bare type; named ones a (name, type) pair.
    member_types.append(embedded.type_signature if name is None else
                        (name, embedded.type_signature))
  return self._strategy.ingest_value(
      structure.Struct(internal_elements),
      computation_types.StructType(member_types))
async def _zip(self, arg, placement, all_equal):
  """Zips a struct of per-child lists into a federated value of structs.

  `arg.internal_representation` is a struct whose elements are each a list
  with one entry per target executor for `placement`. Target executor `i`
  gets a struct built from the `i`-th entry of every element.

  Raises:
    RuntimeError: If an element list's length differs from the number of
      target executors for `placement`.
  """
  self._check_arg_is_structure(arg)
  py_typecheck.check_type(placement, placement_literals.PlacementLiteral)
  self._check_strategy_compatible_with_placement(placement)
  children = self._target_executors[placement]
  cardinality = len(children)

  named_lists = structure.to_elements(arg.internal_representation)
  for _, per_child in named_lists:
    py_typecheck.check_type(per_child, list)
    if len(per_child) != cardinality:
      raise RuntimeError('Expected {} items, found {}.'.format(
          cardinality, len(per_child)))

  per_child_structs = [
      structure.Struct([(name, per_child[position])
                        for name, per_child in named_lists])
      for position in range(cardinality)
  ]
  zipped = await asyncio.gather(*[
      child.create_struct(child_struct)
      for child, child_struct in zip(children, per_child_structs)
  ])

  member_type = computation_types.StructType(
      ((name, member_type.member) if name else member_type.member
       for name, member_type in structure.iter_elements(arg.type_signature)))
  return FederatedResolvingStrategyValue(
      zipped,
      computation_types.FederatedType(
          member_type, placement, all_equal=all_equal))
def _serialize_struct_type(
    struct_typed_value: Any,
    type_spec: computation_types.StructType) -> _SerializeReturnType:
  """Serializes a value of tuple type.

  Args:
    struct_typed_value: A value accepted by `structure.from_container`, with
      one element per element of `type_spec`.
    type_spec: The `computation_types.StructType` of the value.

  Returns:
    A pair `(result_proto, type_spec)`. `result_proto` is an
    `executor_pb2.Value` holding a struct whose elements are serialized with
    `serialize_value` and named after `type_spec`.

  Raises:
    TypeError: If the element count of the value differs from `type_spec`.
  """
  value_structure = structure.from_container(struct_typed_value)
  # Without this check `zip` below would silently drop surplus elements.
  if len(value_structure) != len(type_spec):
    raise TypeError(
        'Cannot serialize a struct value of '
        f'{len(value_structure)} elements to a struct type '
        f'requiring {len(type_spec)} elements. Trying to serialize'
        f'\n{struct_typed_value!r}\nto\n{type_spec}.')
  type_elem_iter = structure.iter_elements(type_spec)
  val_elem_iter = structure.iter_elements(value_structure)
  tup_elems = []
  for (e_name, e_type), (_, e_val) in zip(type_elem_iter, val_elem_iter):
    e_proto, _ = serialize_value(e_val, e_type)
    tup_elems.append(
        executor_pb2.Value.Struct.Element(
            name=e_name if e_name else None, value=e_proto))
  result_proto = (
      executor_pb2.Value(struct=executor_pb2.Value.Struct(element=tup_elems)))
  return result_proto, type_spec
def _remove_struct_element_names_from_tff_type(type_spec):
  """Removes names of struct elements from `type_spec`.

  Args:
    type_spec: An instance of `computation_types.Type` that must be a tensor,
      a (possibly) nested structure of tensors, or a function. `None` is also
      accepted and returned unchanged.

  Returns:
    A modified version of `type_spec` with element names in structures
    removed. Function parameter and result types are processed recursively.

  Raises:
    TypeError: if arg is of the wrong type.
  """
  if type_spec is None:
    return None

  strip_names = _remove_struct_element_names_from_tff_type
  if isinstance(type_spec, computation_types.FunctionType):
    return computation_types.FunctionType(
        strip_names(type_spec.parameter), strip_names(type_spec.result))
  if isinstance(type_spec, computation_types.TensorType):
    return type_spec

  py_typecheck.check_type(type_spec, computation_types.StructType)
  unnamed_elements = [
      (None, strip_names(element_type))
      for _, element_type in structure.iter_elements(type_spec)
  ]
  return computation_types.StructType(unnamed_elements)
def _compute_summation_type_for_bitwidth(bitwidth, type_spec):
  """Creates a `tff.Type` with dtype based on bitwidth.

  For each tensor, bitwidths below 32 map to a 32-bit integer dtype and the
  rest to a 64-bit one. The dtype is unsigned when the input tensor's dtype
  is unsigned. Shapes are preserved. For a struct, `structure.map_structure`
  is applied over `bitwidth` and `type_spec` together.

  Raises:
    ValueError: If a bitwidth is below 1 or above
      `MAXIMUM_SUPPORTED_BITWIDTH`.
    TypeError: If `type_spec` is neither a tensor nor a struct type.
  """

  def _tensor_type_for_bits(bits, tensor_type):
    if bits < 1 or bits > MAXIMUM_SUPPORTED_BITWIDTH:
      raise ValueError(
          'Encountered an bitwidth that cannot be handled: {b}. '
          'Extended bitwidth must be between [1,{m}].'
          '\nRequested: {r}'.format(
              b=bits, r=bitwidth, m=MAXIMUM_SUPPORTED_BITWIDTH))
    unsigned = tensor_type.dtype.is_unsigned
    if bits < 32:
      summation_dtype = tf.uint32 if unsigned else tf.int32
    else:
      summation_dtype = tf.uint64 if unsigned else tf.int64
    return computation_types.TensorType(
        shape=tensor_type.shape, dtype=summation_dtype)

  if type_spec.is_tensor():
    return _tensor_type_for_bits(bitwidth, type_spec)
  if type_spec.is_struct():
    mapped = structure.map_structure(_tensor_type_for_bits, bitwidth,
                                     type_spec)
    return computation_types.StructType(structure.iter_elements(mapped))
  raise TypeError('Summation types can only be created from TensorType or '
                  'StructType. Received a {t}'.format(t=type_spec))
def is_sum_compatible(type_spec: computation_types.Type) -> bool:
  """Determines if `type_spec` is a type that can be added to itself.

  Sum-compatible types are tensors with a numeric dtype (per
  `is_numeric_dtype`) and a fully defined shape, structs whose elements are
  all sum-compatible, and federated types whose member is sum-compatible.
  Every other kind of type (sequences, functions, abstract types,
  placements, ...) is sum-incompatible.

  Args:
    type_spec: A `computation_types.Type`.

  Returns:
    `True` iff `type_spec` is sum-compatible, `False` otherwise.
  """
  py_typecheck.check_type(type_spec, computation_types.Type)

  if type_spec.is_tensor():
    numeric = is_numeric_dtype(type_spec.dtype)
    return numeric and type_spec.shape.is_fully_defined()
  if type_spec.is_struct():
    return all(
        is_sum_compatible(element_type)
        for _, element_type in structure.iter_elements(type_spec))
  if type_spec.is_federated():
    return is_sum_compatible(type_spec.member)
  return False
def test_namedtuple_elements_two_tuples(self):
  # A Python list of ten dtypes converts to a list-backed struct type whose
  # elements iterate as (name, type) pairs.
  dtypes = [tf.int32] * 10
  struct_type = computation_types.to_type(dtypes)
  self.assertIsInstance(struct_type, computation_types.StructWithPythonType)
  self.assertIs(struct_type.python_container, list)
  for element in structure.iter_elements(struct_type):
    self.assertLen(element, 2)
async def compute(self):
  """Returns the result of computing the embedded value.

  Raises:
    TypeError: If the embedded value is composed of values that are not
      embedded in the executor.
    RuntimeError: If the embedded value is not a kind supported by the
      `FederatingExecutor`.
  """
  value = self._value

  if isinstance(value, executor_value_base.ExecutorValue):
    return await value.compute()

  if isinstance(value, structure.Struct):
    computed = await asyncio.gather(*[
        FederatedResolvingStrategyValue(element, element_type).compute()
        for element, element_type in zip(value, self._type_signature)
    ])
    named_types = structure.iter_elements(self._type_signature)
    return structure.Struct([
        (name, result) for (name, _), result in zip(named_types, computed)
    ])

  if isinstance(value, list):
    py_typecheck.check_type(self._type_signature,
                            computation_types.FederatedType)
    for member in value:
      py_typecheck.check_type(member, executor_value_base.ExecutorValue)
    # All-equal values are replicated, so computing one member suffices.
    if self._type_signature.all_equal:
      return await value[0].compute()
    return await asyncio.gather(*[member.compute() for member in value])

  raise RuntimeError(
      'Computing values of type {} represented as {} is not supported in '
      'this executor.'.format(self._type_signature,
                              py_typecheck.type_string(type(self._value))))
def _unwrap(value):
  """Converts `tf.Tensor` leaves to NumPy, recursing through `Struct`s.

  Any other value is returned unchanged.
  """
  if isinstance(value, tf.Tensor):
    return value.numpy()
  if not isinstance(value, structure.Struct):
    return value
  return structure.Struct([
      (name, _unwrap(element))
      for name, element in structure.iter_elements(value)
  ])
def _pack_into_type(to_pack, type_spec):
  """Pack Tensor value `to_pack` into the nested structure `type_spec`.

  Every tensor leaf of `type_spec` receives `to_pack` broadcast to that
  leaf's shape. Struct levels keep their element names. For any other kind
  of type the function returns `None`.
  """
  if type_spec.is_struct():
    packed = []
    for elem_name, elem_type in structure.iter_elements(type_spec):
      packed.append((elem_name, _pack_into_type(to_pack, elem_type)))
    return structure.Struct(packed)
  if type_spec.is_tensor():
    return tf.broadcast_to(to_pack, type_spec.shape)
  # NOTE(review): other type kinds fall through to an implicit `None`.
  return None
def _unwrap_tensors(self, value):
  """Converts tensor leaves (per `tf.is_tensor`) to NumPy via `.numpy()`.

  Recurses through `structure.Struct` values, keeping element names. All
  other values are returned unchanged.
  """
  if tf.is_tensor(value):
    return value.numpy()
  if not isinstance(value, structure.Struct):
    return value
  converted = [(name, self._unwrap_tensors(element))
               for name, element in structure.iter_elements(value)]
  return structure.Struct(converted)
def _build(comp, scope): """Transforms `comp` to CDF, possibly adding bindings to `scope`.""" # The structure returned by this function is a generalized version of # call-dominant form. This function may result in the patterns specified in # the top-level function's docstring. if comp.is_reference(): return scope.resolve(comp.name) elif comp.is_selection(): source = _build(comp.source, scope) if source.is_struct(): return source[comp.as_index()] return building_blocks.Selection(source, index=comp.as_index()) elif comp.is_struct(): elements = [] for (name, value) in structure.iter_elements(comp): value = _build(value, scope) elements.append((name, value)) return building_blocks.Struct(elements) elif comp.is_call(): function = _build(comp.function, scope) argument = None if comp.argument is None else _build( comp.argument, scope) if function.is_lambda(): if argument is not None: scope = scope.new_child() scope.add_local(function.parameter_name, argument) return _build(function.result, scope) else: return scope.create_binding( building_blocks.Call(function, argument)) elif comp.is_lambda(): scope = scope.new_child_with_bindings() if comp.parameter_name: scope.add_local( comp.parameter_name, building_blocks.Reference(comp.parameter_name, comp.parameter_type)) result = _build(comp.result, scope) block = scope.bindings_to_block_with_result(result) return building_blocks.Lambda(comp.parameter_name, comp.parameter_type, block) elif comp.is_block(): scope = scope.new_child() for (name, value) in comp.locals: scope.add_local(name, _build(value, scope)) return _build(comp.result, scope) elif (comp.is_intrinsic() or comp.is_data() or comp.is_compiled_computation()): _disallow_higher_order(comp, global_comp) return comp elif comp.is_placement(): raise ValueError( f'Found placement {comp} in\n{global_comp}\n' 'but placements are not allowed in local computations.') else: raise ValueError( f'Unrecognized computation kind\n{comp}\nin\n{global_comp}')
def _create_result_tensor(type_spec, value):
  """Packs `value` into `type_spec` recursively.

  Tensor types become a `tf.constant` of the type's dtype and (required to
  be fully defined) shape. Struct types become a `structure.Struct`:
    * if `inferred_value_type` is a struct, `value` is iterated in lockstep
      with the type's elements and element names are kept;
    * otherwise the same `value` is broadcast into every element, and the
      resulting elements are unnamed.

  NOTE(review): `inferred_value_type` is not a parameter or local here; it
  is presumably computed once in an enclosing scope from the top-level
  value, so the same branch is taken at every nesting level. Confirm
  against the enclosing definition.
  """
  if type_spec.is_tensor():
    type_spec.shape.assert_is_fully_defined()
    return tf.constant(value, dtype=type_spec.dtype, shape=type_spec.shape)

  if inferred_value_type.is_struct():
    # Copy the leaf values according to the type_spec structure.
    elements = [
        (name, _create_result_tensor(elem_type, leaf))
        for (name, elem_type), leaf in zip(
            structure.iter_elements(type_spec), value)
    ]
  else:
    # "Broadcast" the value to each level of the type_spec structure.
    elements = [
        (None, _create_result_tensor(elem_type, value))
        for _, elem_type in structure.iter_elements(type_spec)
    ]
  return structure.Struct(elements)
def _format_struct_type_members(struct_type: 'StructType') -> str:
  """Renders a struct type's members as a comma-separated string.

  Named members are rendered as `('name', <repr of type>)` and unnamed
  members as the bare repr of their type.
  """
  rendered = []
  for name, member_type in structure.iter_elements(struct_type):
    if name is None:
      rendered.append(repr(member_type))
    else:
      rendered.append('(\'{}\', {!r})'.format(name, member_type))
  return ', '.join(rendered)
def _deserialize_type_spec(serialize_type_variable, python_container=None):
  """Deserialize a `tff.Type` protocol buffer into a python class instance.

  The serialized `computation_pb2.Type` bytes are read from
  `serialize_type_variable`. When `python_container` is given and the type
  is a struct, the struct is re-wrapped as a `StructWithPythonType` using
  that container. The result is passed through
  `type_conversions.type_to_tf_structure`.
  """
  serialized_bytes = serialize_type_variable.read_value().numpy()
  type_proto = computation_pb2.Type.FromString(serialized_bytes)
  type_spec = type_serialization.deserialize_type(type_proto)
  wants_container = type_spec.is_struct() and python_container is not None
  if wants_container:
    type_spec = computation_types.StructWithPythonType(
        structure.iter_elements(type_spec), python_container)
  return type_conversions.type_to_tf_structure(type_spec)
def update_state(structure, **kwargs):
  """Constructs a new `structure` with new values for fields in `kwargs`.

  This is a helper method for working structured objects in a functional
  manner. This method will create a new structure where the fields named by
  keys in `kwargs` replaced with the associated values.

  NOTE: This method only works on the first level of `structure`, and does
  not recurse in the case of nested structures. A field that is itself a
  structure can be replaced with another structure.

  Args:
    structure: The structure with named fields to update. It is not
      modified.
    **kwargs: The list of key-value pairs of fields to update in
      `structure`.

  Returns:
    A new instance of the same type of `structure`, with the fields named in
    the keys of `**kwargs` replaced with the associated values.

  Raises:
    KeyError: If kwargs contains a field that is not in structure.
    TypeError: If structure is not a structure with named fields.
  """
  import copy  # Local import: only needed to shallow-copy mapping inputs.

  if not (py_typecheck.is_named_tuple(structure) or
          py_typecheck.is_attrs(structure) or
          isinstance(structure, (struct_lib.Struct, collections.abc.Mapping))):
    raise TypeError(
        '`structure` must be a structure with named fields (e.g. '
        'dict, attrs class, collections.namedtuple, '
        'tff.structure.Struct), but found {}'.format(type(structure)))
  if isinstance(structure, struct_lib.Struct):
    elements = [(k, v) if k not in kwargs else (k, kwargs.pop(k))
                for k, v in struct_lib.iter_elements(structure)]
    if kwargs:
      raise KeyError(f'`structure` does not contain fields named {kwargs}')
    return struct_lib.Struct(elements)
  elif py_typecheck.is_named_tuple(structure):
    # In Python 3.8 and later `_asdict` no longer returns an OrderedDict but a
    # regular `dict`, so we wrap here to get consistent types across Python
    # versions.
    d = collections.OrderedDict(structure._asdict())
  elif py_typecheck.is_attrs(structure):
    # `recurse=False` keeps nested attrs instances as-is; by default
    # `attr.asdict` would turn them into dicts, contradicting the
    # first-level-only contract above.
    d = attr.asdict(
        structure, dict_factory=collections.OrderedDict, recurse=False)
  else:
    # Shallow copy (same mapping type) so the caller's mapping is left
    # untouched and a new structure is returned, as documented.
    d = copy.copy(structure)
  # Validate for every non-Struct kind so unknown fields raise the documented
  # KeyError instead of a TypeError from the constructor below.
  for key in kwargs:
    if key not in d:
      raise KeyError(
          'structure does not contain a field named "{!s}"'.format(key))
  d.update(kwargs)
  if isinstance(structure, collections.abc.Mapping):
    return d
  return type(structure)(**d)
def _partition_value(
    val: _PartitioningValue,
    type_signature: computation_types.Type) -> _PartitioningValue:
  """Partitions value as specified in _split_value_into_subrounds.

  * Struct types: every element is partitioned against the same incoming
    counters. The returned counters are those produced by the last element,
    or the incoming ones when the struct is empty.
  * Non-all-equal clients-placed federated types: the next
    `ceil(num_remaining_clients / num_remaining_partitions)` client values,
    starting at `last_client_index`, are sliced off, and the counters are
    advanced.
  * Anything else, including all-equal clients-placed values: `val` is
    returned unchanged (replicated for every subround).
  """
  if type_signature.is_struct():
    struct_val = structure.from_container(val.payload)
    result_container = []
    # Start from the incoming state so an empty struct does not leave the
    # counters undefined (previously an UnboundLocalError).
    last_result = val
    for (_, val_elem), (name, type_elem) in zip(
        structure.iter_elements(struct_val),
        structure.iter_elements(type_signature)):
      partitioning_val_elem = _PartitioningValue(val_elem,
                                                 val.num_remaining_clients,
                                                 val.num_remaining_partitions,
                                                 val.last_client_index)
      last_result = _partition_value(partitioning_val_elem, type_elem)
      result_container.append((name, last_result.payload))
    return _PartitioningValue(
        structure.Struct(result_container), last_result.num_remaining_clients,
        last_result.num_remaining_partitions, last_result.last_client_index)
  elif (type_signature.is_federated() and
        type_signature.placement.is_clients()):
    if type_signature.all_equal:
      # In this case we simply replicate the argument for every subround.
      return val
    py_typecheck.check_type(val.payload, Sequence)
    num_clients_for_subround = math.ceil(val.num_remaining_clients /
                                         val.num_remaining_partitions)
    num_remaining_clients = (
        val.num_remaining_clients - num_clients_for_subround)
    num_remaining_partitions = val.num_remaining_partitions - 1
    values_to_return = val.payload[val.last_client_index:val
                                   .last_client_index +
                                   num_clients_for_subround]
    last_client_index = val.last_client_index + num_clients_for_subround
    return _PartitioningValue(
        payload=values_to_return,
        num_remaining_clients=num_remaining_clients,
        num_remaining_partitions=num_remaining_partitions,
        last_client_index=last_client_index)
  else:
    return val
def _unwrap_execution_context_value(val):
  """Recursively removes wrapping from `val` under anonymous tuples.

  `structure.Struct` values are rebuilt element-wise (names kept). An
  `ExecutionContextValue` is replaced by its unwrapped `.value`. Anything
  else is returned unchanged.
  """
  if isinstance(val, structure.Struct):
    return structure.Struct([
        (name, _unwrap_execution_context_value(element))
        for name, element in structure.iter_elements(val)
    ])
  if isinstance(val, ExecutionContextValue):
    return _unwrap_execution_context_value(val.value)
  return val
async def _delegate(val, type_spec: computation_types.Type,
                    target_executor: executor_base.Executor):
  """Delegates value representation to target executor.

  Args:
    val: A value representation to delegate.
    type_spec: The TFF type.
    target_executor: The target executor to delegate.

  Returns:
    An instance of `executor_value_base.ExecutorValue` owned by the target
    executor. Returns `None` for a `None` input, and an input that is already
    an `ExecutorValue` is returned as-is.

  Raises:
    ValueError: If a struct value's element count or element names do not
      match `type_spec`.
  """
  py_typecheck.check_type(target_executor, executor_base.Executor)
  if val is None:
    return None
  if isinstance(val, executor_value_base.ExecutorValue):
    return val
  if isinstance(val, _Sequence):
    computed = await val.compute()
    return await target_executor.create_value(computed, type_spec)
  if not isinstance(val, structure.Struct):
    return await target_executor.create_value(val, type_spec)

  if len(val) != len(type_spec):
    raise ValueError('Found {} elements and {} types in a struct {}.'.format(
        len(val), len(type_spec), str(val)))

  names = []
  pending = []
  for (el_name, el), (el_type_name, el_type) in zip(
      structure.iter_elements(val), structure.iter_elements(type_spec)):
    if el_name != el_type_name:
      raise ValueError(
          'Element name mismatch between value ({}) and type ({}).'.format(
              str(val), str(type_spec)))
    names.append(el_name)
    pending.append(_delegate(el, el_type, target_executor))
  # Delegate all elements concurrently, then rebuild the struct in order.
  delegated = await asyncio.gather(*pending)
  return await target_executor.create_struct(
      structure.Struct(list(zip(names, delegated))))
def _check_or_get_unbound_abstract_type_labels(type_spec, bound_labels,
                                               check):
  """Checks or collects abstract type labels from 'type_spec'.

  This is a helper function used by 'check_abstract_types_are_bound', not to
  be exported out of this module.

  Args:
    type_spec: An instance of computation_types.Type.
    bound_labels: A set of string labels that refer to 'bound' abstract
      types, i.e., ones that appear on the parameter side of a functional
      type.
    check: A bool value. If True, no new unbound type labels are permitted,
      and if False, any new labels encountered are returned as a set.

  Returns:
    A set of abstract type labels introduced in 'type_spec' that don't yet
    appear in the set 'bound_labels'. Labels that first appear in the
    parameter of a function type are always collected (never checked). They
    are included in the result even when check is True. Tensor and placement
    types contribute no labels.

  Raises:
    TypeError: if unbound labels are found and check is True.
  """
  py_typecheck.check_type(type_spec, computation_types.Type)
  if type_spec.is_tensor():
    return set()
  elif type_spec.is_placement():
    # Placement types carry no abstract labels. Without this branch the
    # function returned None, breaking the set unions below.
    return set()
  elif type_spec.is_sequence():
    return _check_or_get_unbound_abstract_type_labels(type_spec.element,
                                                      bound_labels, check)
  elif type_spec.is_federated():
    return _check_or_get_unbound_abstract_type_labels(type_spec.member,
                                                      bound_labels, check)
  elif type_spec.is_struct():
    return set().union(*[
        _check_or_get_unbound_abstract_type_labels(v, bound_labels, check)
        for _, v in structure.iter_elements(type_spec)
    ])
  elif type_spec.is_abstract():
    if type_spec.label in bound_labels:
      return set()
    elif not check:
      return set([type_spec.label])
    else:
      raise TypeError('Unbound type label \'{}\'.'.format(type_spec.label))
  elif type_spec.is_function():
    if type_spec.parameter is None:
      parameter_labels = set()
    else:
      # Labels in the parameter are what bind the result's labels, so they
      # are collected rather than checked.
      parameter_labels = _check_or_get_unbound_abstract_type_labels(
          type_spec.parameter, bound_labels, False)
    result_labels = _check_or_get_unbound_abstract_type_labels(
        type_spec.result, bound_labels.union(parameter_labels), check)
    return parameter_labels.union(result_labels)