Ejemplo n.º 1
0
def _to_struct_internal_rep(
    *, value: Any, tf_function_cache: MutableMapping[str, Any],
    type_spec: computation_types.StructType,
    device: tf.config.LogicalDevice) -> structure.Struct:
  """Builds the TF-executor internal representation of a struct value.

  Args:
    value: A Python container accepted by `structure.from_container`.
    tf_function_cache: Cache forwarded unchanged to
      `to_representation_for_type` for every element.
    type_spec: The struct type describing `value`, element by element.
    device: The device forwarded unchanged to `to_representation_for_type`.

  Returns:
    A `structure.Struct` whose names are taken from `type_spec` and whose
    values are the converted elements of `value`.

  Raises:
    TypeError: If the element counts differ, or if a named element of `value`
      has a name different from the matching element of `type_spec`.
  """
  typed_elements = structure.iter_elements(type_spec)
  as_struct = structure.from_container(value)
  valued_elements = structure.iter_elements(as_struct)

  expected_count = len(type_spec)
  actual_count = len(as_struct)
  if expected_count != actual_count:
    raise TypeError('Mismatched number of elements between type spec and value '
                    'in `to_representation_for_type`. Type spec has {} '
                    'elements, value has {}.'.format(expected_count,
                                                     actual_count))

  converted = []
  for typed, valued in zip(typed_elements, valued_elements):
    expected_name, elem_type = typed
    actual_name, elem_val = valued
    # Unnamed value elements are matched by position only; named ones must
    # agree with the name in the type.
    if actual_name is not None and expected_name != actual_name:
      raise TypeError(
          'Mismatching element names in type vs. value: {} vs. {}.'.format(
              expected_name, actual_name))
    converted.append(
        (expected_name,
         to_representation_for_type(elem_val, tf_function_cache, elem_type,
                                    device)))
  return structure.Struct(converted)
Ejemplo n.º 2
0
def is_valid_bitwidth_type_for_value_type(
    bitwidth_type: computation_types.Type,
    value_type: computation_types.Type) -> bool:
  """Reports whether `bitwidth_type` can carry bitwidths for `value_type`.

  A tensor `value_type` requires a tensor `bitwidth_type` holding exactly one
  integer. A struct `value_type` requires a struct `bitwidth_type` with the
  same number of elements, the same element names in the same order, and
  element types that are themselves valid recursively. Every other pairing is
  rejected.

  Args:
    bitwidth_type: The candidate bitwidth type.
    value_type: The type of the value the bitwidth applies to.

  Returns:
    `True` if the pairing is valid, `False` otherwise.
  """
  # NOTE: this function is primarily a helper for `intrinsic_factory.py`'s
  # `federated_secure_sum` function.
  py_typecheck.check_type(bitwidth_type, computation_types.Type)
  py_typecheck.check_type(value_type, computation_types.Type)

  if value_type.is_tensor() and bitwidth_type.is_tensor():
    # One bitwidth integer per tensor: the bitwidth need not mirror the value
    # tensor's own dtype or shape.
    return bitwidth_type.dtype.is_integer and (
        bitwidth_type.shape.num_elements() == 1)

  if not (value_type.is_struct() and bitwidth_type.is_struct()):
    return False

  bitwidth_elems = list(structure.iter_elements(bitwidth_type))
  value_elems = list(structure.iter_elements(value_type))
  if len(bitwidth_elems) != len(value_elems):
    return False
  return all(
      b_name == v_name and
      is_valid_bitwidth_type_for_value_type(b_type, v_type)
      for (b_name, b_type), (v_name, v_type) in zip(bitwidth_elems,
                                                    value_elems))
Ejemplo n.º 3
0
def is_valid_bitwidth_type_for_value_type(
    bitwidth_type: computation_types.Type,
    value_type: computation_types.Type) -> bool:
  """Reports whether `bitwidth_type` can carry bitwidths for `value_type`.

  A tensor `bitwidth_type` is judged on its own: it is valid when it holds
  exactly one integer, whatever `value_type` is. Otherwise both types must be
  structs with the same length and element names, and each pair of element
  types must be valid recursively.

  Args:
    bitwidth_type: The candidate bitwidth type.
    value_type: The type of the value the bitwidth applies to.

  Returns:
    `True` if the pairing is valid, `False` otherwise.
  """
  # NOTE: this function is primarily a helper for `intrinsic_factory.py`'s
  # `federated_secure_sum` function.
  py_typecheck.check_type(bitwidth_type, computation_types.Type)
  py_typecheck.check_type(value_type, computation_types.Type)

  if bitwidth_type.is_tensor():
    # A single integer bitwidth may be shared by a tensor value or by every
    # value inside a structure, so `value_type` is not inspected here.
    is_integer = bitwidth_type.dtype.is_integer
    return is_integer and (bitwidth_type.shape.num_elements() == 1)

  both_structs = value_type.is_struct() and bitwidth_type.is_struct()
  if not both_structs:
    return False

  bitwidth_pairs = list(structure.iter_elements(bitwidth_type))
  value_pairs = list(structure.iter_elements(value_type))
  if len(bitwidth_pairs) != len(value_pairs):
    return False

  for bitwidth_pair, value_pair in zip(bitwidth_pairs, value_pairs):
    bitwidth_name, bitwidth_elem_type = bitwidth_pair
    value_name, value_elem_type = value_pair
    if bitwidth_name != value_name:
      return False
    if not is_valid_bitwidth_type_for_value_type(bitwidth_elem_type,
                                                 value_elem_type):
      return False
  return True
Ejemplo n.º 4
0
def is_single_integer_or_matches_structure(
    type_sig: computation_types.Type,
    shape_type: computation_types.Type) -> bool:
  """Checks `type_sig` is one integer or an integer struct like `shape_type`.

  Args:
    type_sig: The type being validated.
    shape_type: The type whose struct layout `type_sig` must follow when
      `type_sig` is not a tensor.

  Returns:
    `True` if `type_sig` is a tensor holding exactly one integer, or if both
    arguments are structs with the same length and element names whose
    elements satisfy this check recursively. `False` otherwise.
  """
  py_typecheck.check_type(type_sig, computation_types.Type)
  py_typecheck.check_type(shape_type, computation_types.Type)

  if type_sig.is_tensor():
    # A lone integer is accepted whether `shape_type` is a tensor or a
    # structure: the same integer can apply to every value in the structure.
    return type_sig.dtype.is_integer and (type_sig.shape.num_elements() == 1)

  if not (shape_type.is_struct() and type_sig.is_struct()):
    return False

  sig_elems = list(structure.iter_elements(type_sig))
  shape_elems = list(structure.iter_elements(shape_type))
  if len(type_sig) != len(shape_elems):
    return False
  return all(
      sig_name == shape_name and
      is_single_integer_or_matches_structure(sig_elem_type, shape_elem_type)
      for (sig_name, sig_elem_type), (shape_name, shape_elem_type) in zip(
          sig_elems, shape_elems))
Ejemplo n.º 5
0
def _serialize_struct_type(
    struct_typed_value: Any,
    type_spec: computation_types.StructType,
) -> 'tuple[executor_pb2.Value, computation_types.StructType]':
    """Serializes a value of struct type into an `executor_pb2.Value`.

    Args:
      struct_typed_value: A container accepted by `structure.from_container`.
      type_spec: The struct type to serialize to. Element names in the
        resulting proto are taken from this type, not from the value.

    Returns:
      A `(value_proto, type_spec)` pair, where `value_proto` holds one
      `Struct.Element` per element of `type_spec` and `type_spec` is the
      argument returned unchanged.

    Raises:
      TypeError: If the value and the type have different element counts.
    """
    value_structure = structure.from_container(struct_typed_value)
    if len(value_structure) != len(type_spec):
        raise TypeError(
            'Cannot serialize a struct value of '
            f'{len(value_structure)} elements to a struct type '
            f'requiring {len(type_spec)} elements. Trying to serialize'
            f'\n{struct_typed_value!r}\nto\n{type_spec}.')
    type_elem_iter = structure.iter_elements(type_spec)
    val_elem_iter = structure.iter_elements(value_structure)
    elements = []
    for (e_name, e_type), (_, e_val) in zip(type_elem_iter, val_elem_iter):
        e_value, _ = serialize_value(e_val, e_type)
        # Leave the proto `name` field unset for unnamed elements rather than
        # passing a falsy name through.
        if e_name:
            element = executor_pb2.Value.Struct.Element(name=e_name,
                                                        value=e_value)
        else:
            element = executor_pb2.Value.Struct.Element(value=e_value)
        elements.append(element)
    value_proto = executor_pb2.Value(struct=executor_pb2.Value.Struct(
        element=elements))
    return value_proto, type_spec
Ejemplo n.º 6
0
 def from_tff_result(cls, struct):
     """Builds a `cls` instance from a struct of model weights.

     NOTE(review): takes `cls`, so presumably decorated as a classmethod
     outside this view -- confirm.

     Args:
       struct: A `structure.Struct` exposing `trainable` and `non_trainable`
         attributes, each iterable with `structure.iter_elements`.

     Returns:
       `cls(trainable_values, non_trainable_values)`, where each argument is
       the list of element values with the names dropped.
     """
     py_typecheck.check_type(struct, structure.Struct)
     trainable_values = [
         value for _, value in structure.iter_elements(struct.trainable)
     ]
     non_trainable_values = [
         value for _, value in structure.iter_elements(struct.non_trainable)
     ]
     return cls(trainable_values, non_trainable_values)
 async def create_struct(self, elements):
     """Creates a struct in the target executor from wrapped elements.

     Each element's `internal_representation` is extracted, the results are
     packed into a `structure.Struct` under the original names, and the
     target executor's `create_struct` call is routed through
     `self._delegate`.

     Args:
       elements: A container accepted by `structure.from_container` whose
         values expose `internal_representation`.

     Returns:
       Whatever `self._delegate` yields for the target executor's call.
     """
     as_struct = structure.from_container(elements)
     unwrapped = structure.Struct(
         (name, wrapped.internal_representation)
         for name, wrapped in structure.iter_elements(as_struct))
     pending = self._target_executor.create_struct(unwrapped)
     return await self._delegate(pending)
Ejemplo n.º 8
0
def _create_structure_of_coro_references(
    coro: Coroutine[Any, Any,
                    Any], type_signature: computation_types.Type) -> Any:
  """Wraps `coro` in a structure of `CoroValueReference`s shaped like the type.

  Args:
    coro: The coroutine producing the value described by `type_signature`.
    type_signature: The type of the value `coro` produces.

  Returns:
    A `structure.Struct` of references for struct types (built recursively),
    or a single `CoroValueReference` for tensor and sequence types.
    Server-placed federated types are handled as their member type.

  Raises:
    NotImplementedError: If `type_signature` is of any other kind.
  """
  py_typecheck.check_type(type_signature, computation_types.Type)

  if type_signature.is_struct():

    async def _await_as_struct(
        pending: Coroutine[Any, Any, Any]) -> structure.Struct:
      return structure.from_container(await pending)

    async def _select(shared: Awaitable[structure.Struct],
                      position: int) -> Any:
      resolved = await shared
      return resolved[position]

    # Every element awaits the same underlying result, so the coroutine is
    # wrapped in a shared awaitable before being handed to each selector.
    shared_struct = async_utils.SharedAwaitable(_await_as_struct(coro))

    references = []
    named_types = structure.iter_elements(type_signature)
    for position, (name, element_type) in enumerate(named_types):
      element_coro = _select(shared_struct, position)
      references.append(
          (name,
           _create_structure_of_coro_references(element_coro, element_type)))
    return structure.Struct(references)

  if (type_signature.is_federated() and
      type_signature.placement == placements.SERVER):
    return _create_structure_of_coro_references(coro, type_signature.member)

  if type_signature.is_sequence() or type_signature.is_tensor():
    return CoroValueReference(coro, type_signature)

  raise NotImplementedError(f'Unexpected type found: {type_signature}.')
Ejemplo n.º 9
0
async def _materialize_structure_of_value_references(
    value: Any, type_signature: computation_types.Type) -> Any:
  """Replaces value references inside `value` with their materialized values.

  Args:
    value: A value, possibly nested, that may contain
      `value_reference.MaterializableValueReference` objects.
    type_signature: The type describing `value`.

  Returns:
    For struct types, a `structure.Struct` whose elements were materialized
    concurrently and are named after the type's elements. For tensor and
    sequence types, the result of `get_value()` when `value` is a
    materializable reference, else `value` itself. Server-placed federated
    types are handled as their member type. Any other type returns `value`
    unchanged.
  """
  py_typecheck.check_type(type_signature, computation_types.Type)

  if type_signature.is_struct():
    struct_value = structure.from_container(value)
    named_types = list(structure.iter_elements(type_signature))
    pending = [
        _materialize_structure_of_value_references(element, element_type)
        for element, (_, element_type) in zip(struct_value, named_types)
    ]
    materialized = await asyncio.gather(*pending)
    return structure.Struct([
        (name, element)
        for element, (name, _) in zip(materialized, named_types)
    ])

  if (type_signature.is_federated() and
      type_signature.placement == placements.SERVER):
    return await _materialize_structure_of_value_references(
        value, type_signature.member)

  if type_signature.is_sequence() or type_signature.is_tensor():
    if isinstance(value, value_reference.MaterializableValueReference):
      return await value.get_value()
    return value

  return value
Ejemplo n.º 10
0
def _repackage_partitioned_values(after_merge_results,
                                  result_type_spec: computation_types.Type):
    """Recombines per-subround results into one value of `result_type_spec`.

    Inverts `_split_value_into_subrounds`.

    Args:
      after_merge_results: A list with one result per subround.
      result_type_spec: The type of the recombined result.

    Returns:
      For struct types, a `structure.Struct` recombined element by element.
      For clients-placed federated types that are not all-equal, the
      concatenation (via `+`) of the per-subround lists or tuples. In every
      other case, the first subround's result.
    """
    py_typecheck.check_type(after_merge_results, list)

    if result_type_spec.is_struct():
        per_subround = [
            structure.from_container(result) for result in after_merge_results
        ]
        repackaged = [
            (name,
             _repackage_partitioned_values(
                 [subround[position] for subround in per_subround],
                 elem_type))
            for position, (name, elem_type) in enumerate(
                structure.iter_elements(result_type_spec))
        ]
        return structure.Struct(repackaged)

    clients_placed = (result_type_spec.is_federated()
                      and result_type_spec.placement.is_clients())
    if not clients_placed or result_type_spec.all_equal:
        return after_merge_results[0]

    for partial in after_merge_results:
        py_typecheck.check_type(partial, (list, tuple))
    # Merges all clients-placed values back together.
    return functools.reduce(lambda left, right: left + right,
                            after_merge_results)
Ejemplo n.º 11
0
  async def create_struct(
      self, elements: List[executor_value_base.ExecutorValue]
  ) -> executor_value_base.ExecutorValue:
    """Creates an embedded struct out of the given `elements`.

    Args:
      elements: A collection of embedded values, accepted by
        `structure.from_container`.

    Returns:
      The result of `self._strategy.ingest_value` for a `structure.Struct` of
      the elements' internal representations and the matching
      `computation_types.StructType`.

    Raises:
      TypeError: If the `elements` are not embedded in the executor.
    """
    named_elements = structure.iter_elements(
        structure.from_container(elements))
    representations = []
    type_entries = []
    for name, element in named_elements:
      py_typecheck.check_type(element, executor_value_base.ExecutorValue)
      representations.append((name, element.internal_representation))
      element_type = element.type_signature
      # Unnamed elements contribute a bare type; named ones a (name, type).
      type_entries.append(
          element_type if name is None else (name, element_type))
    struct_value = structure.Struct(representations)
    struct_type = computation_types.StructType(type_entries)
    return self._strategy.ingest_value(struct_value, struct_type)
Ejemplo n.º 12
0
 async def _zip(self, arg, placement, all_equal):
     """Turns a struct of per-child lists into a per-child list of structs.

     Args:
       arg: A value whose `internal_representation` is a struct in which
         every element is a list with one entry per child executor.
       placement: The `placement_literals.PlacementLiteral` selecting the
         child executors.
       all_equal: Forwarded as the `all_equal` bit of the resulting type.

     Returns:
       A `FederatedResolvingStrategyValue` holding one struct per child
       executor, typed as a federated struct of the members of the types of
       `arg`'s elements.

     Raises:
       RuntimeError: If any element does not hold exactly one item per
         child executor.
     """
     self._check_arg_is_structure(arg)
     py_typecheck.check_type(placement, placement_literals.PlacementLiteral)
     self._check_strategy_compatible_with_placement(placement)
     children = self._target_executors[placement]
     expected = len(children)
     named_lists = structure.to_elements(arg.internal_representation)
     for _, per_child in named_lists:
         py_typecheck.check_type(per_child, list)
         if len(per_child) != expected:
             raise RuntimeError('Expected {} items, found {}.'.format(
                 expected, len(per_child)))
     # Transpose: one struct per child, built from that child's entries.
     per_child_structs = [
         structure.Struct([(name, values[idx])
                           for name, values in named_lists])
         for idx in range(expected)
     ]
     zipped = await asyncio.gather(*[
         child.create_struct(child_struct)
         for child, child_struct in zip(children, per_child_structs)
     ])
     member_type = computation_types.StructType(
         ((name, elem_type.member) if name else elem_type.member
          for name, elem_type in structure.iter_elements(
              arg.type_signature)))
     federated_type = computation_types.FederatedType(member_type,
                                                      placement,
                                                      all_equal=all_equal)
     return FederatedResolvingStrategyValue(zipped, federated_type)
Ejemplo n.º 13
0
def _serialize_struct_type(
        struct_typed_value: Any,
        type_spec: computation_types.StructType) -> _SerializeReturnType:
    """Serializes a value of struct type into an `executor_pb2.Value`.

    Elements of the value and of `type_spec` are paired by position. The
    names written to the proto come from `type_spec`.

    Args:
      struct_typed_value: A container accepted by `structure.from_container`.
      type_spec: The struct type to serialize to.

    Returns:
      A `(value_proto, type_spec)` pair.
    """
    typed = structure.iter_elements(type_spec)
    valued = structure.iter_elements(
        structure.from_container(struct_typed_value))
    element_protos = []
    for (elem_name, elem_type), (_, elem_value) in zip(typed, valued):
        serialized, _ = serialize_value(elem_value, elem_type)
        # A falsy name is passed as None.
        element_protos.append(
            executor_pb2.Value.Struct.Element(name=elem_name or None,
                                              value=serialized))
    struct_proto = executor_pb2.Value.Struct(element=element_protos)
    return executor_pb2.Value(struct=struct_proto), type_spec
Ejemplo n.º 14
0
def _remove_struct_element_names_from_tff_type(type_spec):
    """Returns `type_spec` with every struct element name removed.

    Args:
      type_spec: `None`, or a `computation_types.Type` that is a tensor, a
        (possibly nested) structure, or a function.

    Returns:
      `None` for `None`; tensor types unchanged; function types rebuilt from
      their stripped parameter and result; struct types rebuilt with every
      element unnamed and stripped recursively.

    Raises:
      TypeError: If `type_spec` is of the wrong type.
    """
    if type_spec is None:
        return None
    if isinstance(type_spec, computation_types.FunctionType):
        stripped_parameter = _remove_struct_element_names_from_tff_type(
            type_spec.parameter)
        stripped_result = _remove_struct_element_names_from_tff_type(
            type_spec.result)
        return computation_types.FunctionType(stripped_parameter,
                                              stripped_result)
    if isinstance(type_spec, computation_types.TensorType):
        return type_spec
    py_typecheck.check_type(type_spec, computation_types.StructType)
    unnamed = []
    for _, element_type in structure.iter_elements(type_spec):
        unnamed.append(
            (None, _remove_struct_element_names_from_tff_type(element_type)))
    return computation_types.StructType(unnamed)
Ejemplo n.º 15
0
def _compute_summation_type_for_bitwidth(bitwidth, type_spec):
    """Creates a `tff.Type` whose dtypes are chosen from `bitwidth`.

    Args:
      bitwidth: A bitwidth for a tensor `type_spec`, or a structure of
        bitwidths mapped together with a struct `type_spec`.
      type_spec: A tensor or struct type.

    Returns:
      A type shaped like `type_spec` in which every tensor keeps its shape
      and gets a 32-bit integer dtype when its bitwidth is below 32, or a
      64-bit one otherwise. Signedness follows the original dtype.

    Raises:
      ValueError: If a bitwidth is outside `[1, MAXIMUM_SUPPORTED_BITWIDTH]`.
      TypeError: If `type_spec` is neither a tensor nor a struct type.
    """
    def _tensor_type_for(bits, tensor_type):
        if bits < 1 or bits > MAXIMUM_SUPPORTED_BITWIDTH:
            raise ValueError(
                'Encountered an bitwidth that cannot be handled: {b}. '
                'Extended bitwidth must be between [1,{m}].'
                '\nRequested: {r}'.format(b=bits,
                                          r=bitwidth,
                                          m=MAXIMUM_SUPPORTED_BITWIDTH))
        shape = tensor_type.shape
        unsigned = tensor_type.dtype.is_unsigned
        if bits < 32:
            dtype = tf.uint32 if unsigned else tf.int32
        else:
            dtype = tf.uint64 if unsigned else tf.int64
        return computation_types.TensorType(shape=shape, dtype=dtype)

    if type_spec.is_tensor():
        return _tensor_type_for(bitwidth, type_spec)
    if type_spec.is_struct():
        mapped = structure.map_structure(_tensor_type_for, bitwidth,
                                         type_spec)
        return computation_types.StructType(structure.iter_elements(mapped))
    raise TypeError(
        'Summation types can only be created from TensorType or '
        'StructType. Received a {t}'.format(t=type_spec))
Ejemplo n.º 16
0
def is_sum_compatible(type_spec: computation_types.Type) -> bool:
    """Determines if `type_spec` is a type that can be added to itself.

    A tensor type qualifies when its dtype is numeric and its shape is fully
    defined. A struct type qualifies when every element does. A federated
    type qualifies when its member does. All other types (for example
    sequences, functions, abstract types and placements) do not.

    Args:
      type_spec: A `computation_types.Type`.

    Returns:
      `True` iff `type_spec` is sum-compatible, `False` otherwise.
    """
    py_typecheck.check_type(type_spec, computation_types.Type)
    if type_spec.is_tensor():
        numeric = is_numeric_dtype(type_spec.dtype)
        return numeric and type_spec.shape.is_fully_defined()
    if type_spec.is_struct():
        for _, element_type in structure.iter_elements(type_spec):
            if not is_sum_compatible(element_type):
                return False
        return True
    if type_spec.is_federated():
        return is_sum_compatible(type_spec.member)
    return False
Ejemplo n.º 17
0
 def test_namedtuple_elements_two_tuples(self):
     """Every element yielded for a list-built struct type is a 2-tuple."""
     ten_int32s = [tf.int32] * 10
     struct_type = computation_types.to_type(ten_int32s)
     self.assertIsInstance(struct_type,
                           computation_types.StructWithPythonType)
     self.assertIs(struct_type.python_container, list)
     for element in structure.iter_elements(struct_type):
         self.assertLen(element, 2)
Ejemplo n.º 18
0
  async def compute(self):
    """Returns the result of computing the embedded value.

    An embedded executor value is computed directly. A `structure.Struct` is
    computed element by element, concurrently, and re-assembled under the
    element names of the type signature. A list of executor values is
    computed as its first member when the federated type is all-equal, and
    as the list of all members' results otherwise.

    Raises:
      TypeError: If the embedded value is composed of values that are not
        embedded in the executor.
      RuntimeError: If the embedded value is not a kind supported by the
        `FederatingExecutor`.
    """
    embedded = self._value

    if isinstance(embedded, executor_value_base.ExecutorValue):
      return await embedded.compute()

    if isinstance(embedded, structure.Struct):
      pending = [
          FederatedResolvingStrategyValue(element, element_type).compute()
          for element, element_type in zip(embedded, self._type_signature)
      ]
      computed = await asyncio.gather(*pending)
      named_types = structure.iter_elements(self._type_signature)
      return structure.Struct(
          (name, result) for (name, _), result in zip(named_types, computed))

    if isinstance(embedded, list):
      py_typecheck.check_type(self._type_signature,
                              computation_types.FederatedType)
      for member in embedded:
        py_typecheck.check_type(member, executor_value_base.ExecutorValue)
      if self._type_signature.all_equal:
        return await embedded[0].compute()
      return await asyncio.gather(*[member.compute() for member in embedded])

    raise RuntimeError(
        'Computing values of type {} represented as {} is not supported in '
        'this executor.'.format(self._type_signature,
                                py_typecheck.type_string(type(self._value))))
Ejemplo n.º 19
0
def _unwrap(value):
    """Converts `tf.Tensor`s in `value` to the result of their `numpy()` call.

    `structure.Struct` values are rebuilt with each element unwrapped
    recursively under its original name. Anything else is returned as is.
    """
    if isinstance(value, tf.Tensor):
        return value.numpy()
    if not isinstance(value, structure.Struct):
        return value
    return structure.Struct(
        (name, _unwrap(element))
        for name, element in structure.iter_elements(value))
Ejemplo n.º 20
0
 def _pack_into_type(to_pack, type_spec):
   """Packs tensor value `to_pack` into the nested structure `type_spec`.

   Each tensor leaf of `type_spec` receives `to_pack` broadcast to that
   leaf's shape. Struct types are packed element by element under the same
   names. Any other kind of type yields `None`.
   """
   if type_spec.is_struct():
     packed = []
     for name, elem_type in structure.iter_elements(type_spec):
       packed.append((name, _pack_into_type(to_pack, elem_type)))
     return structure.Struct(packed)
   if type_spec.is_tensor():
     return tf.broadcast_to(to_pack, type_spec.shape)
   return None
Ejemplo n.º 21
0
 def _unwrap_tensors(self, value):
     """Replaces tensors in `value` with the result of their `numpy()` call.

     `structure.Struct` values are rebuilt with each element unwrapped
     recursively under its original name. Anything else is returned as is.
     """
     if tf.is_tensor(value):
         return value.numpy()
     if not isinstance(value, structure.Struct):
         return value
     return structure.Struct(
         (name, self._unwrap_tensors(element))
         for name, element in structure.iter_elements(value))
Ejemplo n.º 22
0
 def _build(comp, scope):
     """Transforms `comp` to CDF, possibly adding bindings to `scope`.

     Dispatches on the kind of `comp` and recurses into its children.

     NOTE(review): `global_comp` and `_disallow_higher_order` are not defined
     in this block; presumably they come from the enclosing function or
     module -- confirm.

     Args:
       comp: The building block to transform.
       scope: The scope used to resolve references and to collect bindings.

     Returns:
       The transformed building block.

     Raises:
       ValueError: If `comp` is a placement or an unrecognized kind.
     """
     # The structure returned by this function is a generalized version of
     # call-dominant form. This function may result in the patterns specified in
     # the top-level function's docstring.
     if comp.is_reference():
         # References are replaced by whatever the scope resolves them to.
         return scope.resolve(comp.name)
     elif comp.is_selection():
         source = _build(comp.source, scope)
         # When the built source is itself a struct, pick the element
         # directly instead of emitting a Selection node.
         if source.is_struct():
             return source[comp.as_index()]
         return building_blocks.Selection(source, index=comp.as_index())
     elif comp.is_struct():
         # Rebuild the struct from its individually transformed elements,
         # keeping the element names.
         elements = []
         for (name, value) in structure.iter_elements(comp):
             value = _build(value, scope)
             elements.append((name, value))
         return building_blocks.Struct(elements)
     elif comp.is_call():
         function = _build(comp.function, scope)
         argument = None if comp.argument is None else _build(
             comp.argument, scope)
         if function.is_lambda():
             # Calling a lambda: bind the argument to the parameter name in a
             # child scope (only when there is an argument) and build the
             # lambda's result there, so no Call node is produced.
             if argument is not None:
                 scope = scope.new_child()
                 scope.add_local(function.parameter_name, argument)
             return _build(function.result, scope)
         else:
             # Any other call is registered with the scope as a binding, and
             # whatever `create_binding` returns stands in for the call.
             return scope.create_binding(
                 building_blocks.Call(function, argument))
     elif comp.is_lambda():
         # The lambda body is built in its own child scope; the bindings
         # gathered there are wrapped in a block around the result.
         scope = scope.new_child_with_bindings()
         if comp.parameter_name:
             scope.add_local(
                 comp.parameter_name,
                 building_blocks.Reference(comp.parameter_name,
                                           comp.parameter_type))
         result = _build(comp.result, scope)
         block = scope.bindings_to_block_with_result(result)
         return building_blocks.Lambda(comp.parameter_name,
                                       comp.parameter_type, block)
     elif comp.is_block():
         scope = scope.new_child()
         # Locals are built in order and added to the scope one at a time, so
         # each one is built against the locals registered before it.
         for (name, value) in comp.locals:
             scope.add_local(name, _build(value, scope))
         return _build(comp.result, scope)
     elif (comp.is_intrinsic() or comp.is_data()
           or comp.is_compiled_computation()):
         # These kinds are returned unchanged after the higher-order check.
         _disallow_higher_order(comp, global_comp)
         return comp
     elif comp.is_placement():
         raise ValueError(
             f'Found placement {comp} in\n{global_comp}\n'
             'but placements are not allowed in local computations.')
     else:
         raise ValueError(
             f'Unrecognized computation kind\n{comp}\nin\n{global_comp}')
 def _create_result_tensor(type_spec, value):
   """Packs `value` into `type_spec` recursively.

   NOTE(review): reads `inferred_value_type`, which is not defined in this
   block; presumably a variable of the enclosing function -- confirm.

   Args:
     type_spec: The tensor type, or structure of them, to produce.
     value: The value to pack. Iterated alongside the elements of `type_spec`
       when `inferred_value_type` is a struct, and reused for every element
       otherwise.

   Returns:
     A `tf.constant` for tensor types (whose shape must be fully defined),
     else a `structure.Struct` built recursively.
   """
   if type_spec.is_tensor():
     type_spec.shape.assert_is_fully_defined()
     return tf.constant(value, dtype=type_spec.dtype, shape=type_spec.shape)

   value_is_structured = inferred_value_type.is_struct()
   typed_elements = structure.iter_elements(type_spec)
   if value_is_structured:
     # Copy the leaf values according to the type_spec structure.
     packed = [(name, _create_result_tensor(elem_type, elem_value))
               for (name, elem_type), elem_value in zip(typed_elements, value)]
   else:
     # "Broadcast" the value to each level of the type_spec structure.
     packed = [(None, _create_result_tensor(elem_type, value))
               for _, elem_type in typed_elements]
   return structure.Struct(packed)
Ejemplo n.º 24
0
def _format_struct_type_members(struct_type: 'StructType') -> str:
    """Renders the elements of `struct_type` as a comma-separated string.

    Named elements are rendered as `('name', <repr of type>)`. Unnamed
    elements are rendered as the bare `repr` of their type.
    """
    rendered = []
    for name, value in structure.iter_elements(struct_type):
        if name is None:
            rendered.append(repr(value))
        else:
            rendered.append('(\'{}\', {!r})'.format(name, value))
    return ', '.join(rendered)
Ejemplo n.º 25
0
def _deserialize_type_spec(serialize_type_variable, python_container=None):
    """Deserializes a `tff.Type` proto stored in a variable.

    Args:
      serialize_type_variable: An object whose `read_value().numpy()` yields
        the serialized `computation_pb2.Type` bytes.
      python_container: Optional container attached to the deserialized type
        when that type is a struct.

    Returns:
      The result of `type_conversions.type_to_tf_structure` applied to the
      deserialized (and possibly container-annotated) type.
    """
    serialized = serialize_type_variable.read_value().numpy()
    type_proto = computation_pb2.Type.FromString(serialized)
    type_spec = type_serialization.deserialize_type(type_proto)
    attach_container = type_spec.is_struct() and python_container is not None
    if attach_container:
        type_spec = computation_types.StructWithPythonType(
            structure.iter_elements(type_spec), python_container)
    return type_conversions.type_to_tf_structure(type_spec)
Ejemplo n.º 26
0
def update_state(structure, **kwargs):
    """Constructs a new `structure` with new values for fields in `kwargs`.

    This is a helper method for working structured objects in a functional
    manner. This method will create a new structure where the fields named by
    keys in `kwargs` replaced with the associated values. The input
    `structure` is never modified; mappings are shallow-copied before the
    update is applied.

    NOTE: This method only works on the first level of `structure`, and does
    not recurse in the case of nested structures. A field that is itself a
    structure can be replaced with another structure.

    Args:
      structure: The structure with named fields to update.
      **kwargs: The list of key-value pairs of fields to update in
        `structure`.

    Returns:
      A new instance of the same type of `structure`, with the fields named
      in the keys of `**kwargs` replaced with the associated values.

    Raises:
      KeyError: If kwargs contains a field that is not in structure.
      TypeError: If structure is not a structure with named fields.
    """
    # Imported locally so this function does not depend on a module-level
    # import that may not exist in the file.
    import copy

    if not (py_typecheck.is_named_tuple(structure)
            or py_typecheck.is_attrs(structure)
            or isinstance(structure,
                          (struct_lib.Struct, collections.abc.Mapping))):
        raise TypeError(
            '`structure` must be a structure with named fields (e.g. '
            'dict, attrs class, collections.namedtuple, '
            'tff.structure.Struct), but found {}'.format(type(structure)))
    if isinstance(structure, struct_lib.Struct):
        # Each replaced name is popped from `kwargs`, so anything left over
        # afterwards names a field the struct does not have.
        elements = [(k, v) if k not in kwargs else (k, kwargs.pop(k))
                    for k, v in struct_lib.iter_elements(structure)]
        if kwargs:
            raise KeyError(
                f'`structure` does not contain fields named {kwargs}')
        return struct_lib.Struct(elements)
    elif py_typecheck.is_named_tuple(structure):
        # In Python 3.8 and later `_asdict` no longer returns an OrderedDict
        # but a regular `dict`, so we wrap it here to get consistent types
        # across Python versions.
        d = collections.OrderedDict(structure._asdict())
    elif py_typecheck.is_attrs(structure):
        d = attr.asdict(structure, dict_factory=collections.OrderedDict)
    else:
        for key in kwargs:
            if key not in structure:
                raise KeyError(
                    'structure does not contain a field named "{!s}"'.format(
                        key))
        # Shallow-copy the mapping so the update below cannot leak into the
        # caller's object; the contract is to return a new structure.
        d = copy.copy(structure)
    d.update(kwargs)
    if isinstance(structure, collections.abc.Mapping):
        return d
    return type(structure)(**d)
Ejemplo n.º 27
0
def _partition_value(
        val: _PartitioningValue,
        type_signature: computation_types.Type) -> _PartitioningValue:
    """Partitions value as specified in _split_value_into_subrounds.

    Args:
      val: The payload to partition together with its bookkeeping: the number
        of remaining clients, the number of remaining partitions and the
        index of the last client already handed out.
      type_signature: The type describing `val.payload`.

    Returns:
      A `_PartitioningValue` for the current subround. For struct types the
      payload is a `structure.Struct` of partitioned elements and the
      bookkeeping is that of the last element processed (or of `val` itself
      when the struct is empty). For clients-placed federated types that are
      not all-equal, the payload is the next slice of clients and the
      bookkeeping is advanced accordingly. In every other case `val` is
      returned unchanged.
    """
    if type_signature.is_struct():
        struct_val = structure.from_container(val.payload)
        result_container = []
        # Seed with the incoming bookkeeping so that an empty struct returns
        # it unchanged instead of reading an unbound local below.
        partition_result = val
        for (_, val_elem), (name, type_elem) in zip(
                structure.iter_elements(struct_val),
                structure.iter_elements(type_signature)):
            # Every element starts from the same incoming bookkeeping.
            partitioning_val_elem = _PartitioningValue(
                val_elem, val.num_remaining_clients,
                val.num_remaining_partitions, val.last_client_index)
            partition_result = _partition_value(partitioning_val_elem,
                                                type_elem)
            result_container.append((name, partition_result.payload))
        return _PartitioningValue(structure.Struct(result_container),
                                  partition_result.num_remaining_clients,
                                  partition_result.num_remaining_partitions,
                                  partition_result.last_client_index)
    elif (type_signature.is_federated()
          and type_signature.placement.is_clients()):

        if type_signature.all_equal:
            # In this case we simply replicate the argument for every subround.
            return val

        py_typecheck.check_type(val.payload, Sequence)
        # Spread the remaining clients over the remaining partitions,
        # rounding up.
        num_clients_for_subround = math.ceil(val.num_remaining_clients /
                                             val.num_remaining_partitions)
        num_remaining_clients = val.num_remaining_clients - num_clients_for_subround
        num_remaining_partitions = val.num_remaining_partitions - 1
        values_to_return = val.payload[val.last_client_index:val.
                                       last_client_index +
                                       num_clients_for_subround]
        last_client_index = val.last_client_index + num_clients_for_subround
        return _PartitioningValue(
            payload=values_to_return,
            num_remaining_clients=num_remaining_clients,
            num_remaining_partitions=num_remaining_partitions,
            last_client_index=last_client_index)
    else:
        return val
Ejemplo n.º 28
0
def _unwrap_execution_context_value(val):
    """Recursively strips `ExecutionContextValue` wrappers from `val`.

    `structure.Struct` values are rebuilt with every element unwrapped under
    its original name. An `ExecutionContextValue` is replaced by its
    unwrapped `value`. Anything else is returned as is.
    """
    if isinstance(val, structure.Struct):
        return structure.Struct(
            (name, _unwrap_execution_context_value(element))
            for name, element in structure.iter_elements(val))
    if isinstance(val, ExecutionContextValue):
        return _unwrap_execution_context_value(val.value)
    return val
Ejemplo n.º 29
0
async def _delegate(val, type_spec: computation_types.Type,
                    target_executor: executor_base.Executor):
    """Delegates a value representation to the target executor.

    Args:
      val: A value representation to delegate.
      type_spec: The TFF type.
      target_executor: The target executor to delegate.

    Returns:
      `None` when `val` is `None`; `val` itself when it is already an
      `executor_value_base.ExecutorValue`; otherwise the value created in the
      target executor. `_Sequence` values are computed first, and structs are
      delegated element by element before being re-assembled with
      `create_struct`.

    Raises:
      ValueError: If a struct value and its type disagree on the number of
        elements or on an element name.
    """
    py_typecheck.check_type(target_executor, executor_base.Executor)
    if val is None:
        return None
    if isinstance(val, executor_value_base.ExecutorValue):
        return val
    if isinstance(val, _Sequence):
        materialized = await val.compute()
        return await target_executor.create_value(materialized, type_spec)
    if not isinstance(val, structure.Struct):
        return await target_executor.create_value(val, type_spec)

    if len(val) != len(type_spec):
        raise ValueError(
            'Found {} elements and {} types in a struct {}.'.format(
                len(val), len(type_spec), str(val)))
    valued = structure.iter_elements(val)
    typed = structure.iter_elements(type_spec)
    names = []
    pending = []
    for (value_name, element), (type_name, element_type) in zip(valued, typed):
        if value_name != type_name:
            raise ValueError(
                'Element name mismatch between value ({}) and type ({}).'.
                format(str(val), str(type_spec)))
        names.append(value_name)
        pending.append(_delegate(element, element_type, target_executor))
    # Delegate all elements concurrently, then rebuild under the same names.
    delegated = await asyncio.gather(*pending)
    reassembled = structure.Struct(list(zip(names, delegated)))
    return await target_executor.create_struct(reassembled)
Ejemplo n.º 30
0
    def _check_or_get_unbound_abstract_type_labels(type_spec, bound_labels,
                                                   check):
        """Checks or collects abstract type labels from 'type_spec'.

        This is a helper function used by 'check_abstract_types_are_bound',
        not to be exported out of this module.

        Args:
          type_spec: An instance of computation_types.Type.
          bound_labels: A set of string labels that refer to 'bound' abstract
            types, i.e., ones that appear on the parameter side of a
            functional type.
          check: A bool value. If True, no new unbound type labels are
            permitted, and if False, any new labels encountered are returned
            as a set.

        Returns:
          If check is False, a set of new abstract type labels introduced in
          'type_spec' that don't yet appear in the set 'bound_labels'. For
          function types, the labels introduced by the parameter are included
          in the returned set whatever the value of check. Type kinds not
          handled below yield None.

        Raises:
          TypeError: if unbound labels are found and check is True.
        """
        py_typecheck.check_type(type_spec, computation_types.Type)

        if type_spec.is_tensor():
            return set()

        if type_spec.is_sequence():
            return _check_or_get_unbound_abstract_type_labels(
                type_spec.element, bound_labels, check)

        if type_spec.is_federated():
            return _check_or_get_unbound_abstract_type_labels(
                type_spec.member, bound_labels, check)

        if type_spec.is_struct():
            per_element = [
                _check_or_get_unbound_abstract_type_labels(
                    element_type, bound_labels, check)
                for _, element_type in structure.iter_elements(type_spec)
            ]
            return set().union(*per_element)

        if type_spec.is_abstract():
            label = type_spec.label
            if label in bound_labels:
                return set()
            if check:
                raise TypeError('Unbound type label \'{}\'.'.format(label))
            return {label}

        if type_spec.is_function():
            # Labels on the parameter side are always collected (never
            # checked); they count as bound while the result is processed.
            if type_spec.parameter is None:
                parameter_labels = set()
            else:
                parameter_labels = _check_or_get_unbound_abstract_type_labels(
                    type_spec.parameter, bound_labels, False)
            result_labels = _check_or_get_unbound_abstract_type_labels(
                type_spec.result, bound_labels.union(parameter_labels), check)
            return parameter_labels.union(result_labels)

        return None