Ejemplo n.º 1
0
    async def create_struct(self, elements):
        """Builds an `EagerValue` holding a struct assembled from `elements`.

        Args:
          elements: A Python container (or `structure.Struct`) of `EagerValue`
            instances, as documented in `executor_base.Executor`.

        Returns:
          An `EagerValue` whose internal representation is a `structure.Struct`
          of the members' internal representations, typed as a `StructType`
          built from the members' type signatures.

        Raises:
          TypeError: If any member is not an `EagerValue`.
        """
        named_members = structure.to_elements(
            structure.from_container(elements))
        reps = []
        member_types = []
        for name, member in named_members:
            py_typecheck.check_type(member, EagerValue)
            reps.append((name, member.internal_representation))
            # Unnamed members enter the StructType as bare types.
            if name is None:
                member_types.append(member.type_signature)
            else:
                member_types.append((name, member.type_signature))
        struct_value = structure.Struct(reps)
        struct_type = computation_types.StructType(member_types)
        return EagerValue(struct_value, self._tf_function_cache, struct_type)
Ejemplo n.º 2
0
 async def create_struct(self, elements):
     """Creates a struct on the remote service from `elements`.

     Args:
       elements: A Python container (or `structure.Struct`) of `RemoteValue`s.

     Returns:
       A `RemoteValue` referencing the struct created by the service.

     Raises:
       TypeError: If a member is not a `RemoteValue`, or the service response
         is not a `CreateStructResponse`.
     """
     as_struct = structure.from_container(elements)
     request_elements = []
     member_types = []
     for name, member in structure.iter_elements(as_struct):
         py_typecheck.check_type(member, RemoteValue)
         # Falsy names (None or '') are sent and typed as unnamed elements.
         proto_name = name if name else None
         request_elements.append(
             executor_pb2.CreateStructRequest.Element(
                 name=proto_name, value_ref=member.value_ref))
         if name:
             member_types.append((name, member.type_signature))
         else:
             member_types.append(member.type_signature)
     struct_type = computation_types.StructType(member_types)
     request = executor_pb2.CreateStructRequest(element=request_elements)
     if self._bidi_stream is not None:
         wrapped = executor_pb2.ExecuteRequest(create_struct=request)
         stream_response = await self._bidi_stream.send_request(wrapped)
         response = stream_response.create_struct
     else:
         response = _request(self._stub.CreateStruct, request)
     py_typecheck.check_type(response, executor_pb2.CreateStructResponse)
     return RemoteValue(response.value_ref, struct_type, self)
Ejemplo n.º 3
0
 async def create_value(self, value, type_spec=None):
     """Embeds `value`, evaluating computations and splitting structs.

     Args:
       value: A `ComputationImpl`, a `pb.Computation`, a Python container
         (when `type_spec` is a struct type), or anything the target executor
         accepts.
       type_spec: Optional TFF type (or something `to_type` accepts).

     Returns:
       A `ReferenceResolvingExecutorValue`, or, for computation protos, the
       result of `self._evaluate`.
     """
     type_spec = computation_types.to_type(type_spec)
     if isinstance(value, computation_impl.ComputationImpl):
         proto = computation_impl.ComputationImpl.get_proto(value)
         reconciled_type = type_utils.reconcile_value_with_type_spec(
             value, type_spec)
         return await self.create_value(proto, reconciled_type)
     if isinstance(value, pb.Computation):
         return await self._evaluate(value)
     if type_spec is None or not type_spec.is_struct():
         target_value = await self._target_executor.create_value(
             value, type_spec)
         return ReferenceResolvingExecutorValue(target_value)
     # Struct: embed every member concurrently against its member type.
     named_values = structure.to_elements(structure.from_container(value))
     embedded = await asyncio.gather(*(
         self.create_value(member, member_type)
         for (_, member), member_type in zip(named_values, type_spec)))
     return ReferenceResolvingExecutorValue(
         structure.Struct(
             (name, emb) for (name, _), emb in zip(named_values, embedded)))
Ejemplo n.º 4
0
def infer_cardinalities(value, type_spec):
    """Infers placement cardinalities from the Python object `value`.

    Any Python object may represent a federated value; enforcing particular
    representations is left to ingestion functions lower in the stack.

    Args:
      value: Python object from which to infer TFF placement cardinalities.
      type_spec: The TFF type of `value`; only federated (sub)types contribute
        a cardinality.

    Returns:
      Dict mapping placements to cardinalities.

    Raises:
      ValueError: If conflicting cardinalities are inferred from `value`.
      TypeError: If the arguments are of the wrong types, or if `type_spec` is
        a federated type which is not `all_equal` but `value` is not sized.
    """
    py_typecheck.check_not_none(value)
    py_typecheck.check_type(type_spec, computation_types.Type)
    if isinstance(value, cardinality_carrying_base.CardinalityCarrying):
        return value.cardinality
    if type_spec.is_federated():
        if type_spec.all_equal:
            return {}
        py_typecheck.check_type(value, collections.abc.Sized)
        return {type_spec.placement: len(value)}
    if not type_spec.is_struct():
        return {}
    # Struct: recurse into each member positionally and merge the results.
    struct_value = structure.from_container(value, recursive=False)
    merged = {}
    for position, (_, member_type) in enumerate(
            structure.to_elements(type_spec)):
        member_cardinalities = infer_cardinalities(struct_value[position],
                                                   member_type)
        merged = merge_cardinalities(merged, member_cardinalities)
    return merged
Ejemplo n.º 5
0
def _to_struct_internal_rep(
    *, value: Any, tf_function_cache: MutableMapping[str, Any],
    type_spec: computation_types.StructType,
    device: tf.config.LogicalDevice) -> structure.Struct:
  """Converts a Python container into a `Struct` matching `type_spec`.

  Every member is converted with `to_representation_for_type`; the resulting
  struct keeps the element names of `type_spec`.

  Raises:
    TypeError: If `value` and `type_spec` differ in element count or in any
      element name.
  """
  expected = structure.to_elements(type_spec)
  actual = structure.to_elements(structure.from_container(value))
  if len(expected) != len(actual):
    raise TypeError('Expected a {}-element tuple, found {} elements.'.format(
        len(expected), len(actual)))
  converted = []
  for (expected_name, member_type), (actual_name, member_value) in zip(
      expected, actual):
    if expected_name != actual_name:
      raise TypeError(
          'Mismatching element names in type vs. value: {} vs. {}.'.format(
              expected_name, actual_name))
    member_repr = to_representation_for_type(member_value, tf_function_cache,
                                             member_type, device)
    converted.append((expected_name, member_repr))
  return structure.Struct(converted)
Ejemplo n.º 6
0
 async def create_struct(self, elements):
   """Creates (or reuses) a cached struct value from `CachedValue` members.

   The cache key is built from the member names and identifiers, so structs
   assembled from the same members map to the same cache entry and share a
   single call to the target executor's `create_struct`.

   Args:
     elements: A `structure.Struct` or Python container of `CachedValue`s.

   Returns:
     The `CachedValue` stored in `self._cache` under the struct identifier.

   Raises:
     TypeError: If a member is not a `CachedValue` or a name is not a `str`.
     Exception: Anything raised while awaiting the target future; the entire
       cache is cleared before re-raising. `check_assignable_from` also raises
       if the target's type signature does not fit the computed struct type.
   """
   if not isinstance(elements, structure.Struct):
     elements = structure.from_container(elements)
   element_strings = []
   element_kv_pairs = structure.to_elements(elements)
   to_gather = []
   type_elements = []
   for k, v in element_kv_pairs:
     py_typecheck.check_type(v, CachedValue)
     to_gather.append(v.target_future)
     if k is not None:
       py_typecheck.check_type(k, str)
       element_strings.append('{}={}'.format(k, v.identifier))
       type_elements.append((k, v.type_signature))
     else:
       element_strings.append(str(v.identifier))
       type_elements.append(v.type_signature)
   type_spec = computation_types.StructType(type_elements)
   # Resolve every member's target value before building the struct.
   gathered = await asyncio.gather(*to_gather)
   # Key of the form '<name=id,id,...>' derived from the member identifiers.
   identifier = CachedValueIdentifier('<{}>'.format(','.join(element_strings)))
   try:
     cached_value = self._cache[identifier]
   except KeyError:
     # Cache miss: schedule the target struct creation and store the entry
     # immediately so concurrent callers with the same key reuse the future.
     target_future = asyncio.ensure_future(
         self._target_executor.create_struct(
             structure.Struct(
                 (k, v) for (k, _), v in zip(element_kv_pairs, gathered))))
     cached_value = CachedValue(identifier, None, type_spec, target_future)
     self._cache[identifier] = cached_value
   try:
     target_value = await cached_value.target_future
   except Exception:
     # TODO(b/145514490): This is a bit heavy handed, there maybe caches where
     # only the current cache item needs to be invalidated; however this
     # currently only occurs when an inner RemoteExecutor has the backend go
     # down.
     self._cache = {}
     raise
   type_spec.check_assignable_from(target_value.type_signature)
   return cached_value
Ejemplo n.º 7
0
def _serialize_struct_type(
    value_proto: serialization_bindings.Value,
    struct_typed_value: Any,
    type_spec: computation_types.StructType,
) -> 'tuple[serialization_bindings.Value, computation_types.StructType]':
    """Serializes a struct-typed value into `value_proto.struct`.

    Elements are paired positionally: the i-th element of `struct_typed_value`
    is serialized against the i-th element type of `type_spec`. Element names
    are taken from `type_spec`; names on the value side are ignored.

    Args:
      value_proto: The `Value` message to populate; its `struct` field is
        mutated in place.
      struct_typed_value: A Python container accepted by
        `structure.from_container`.
      type_spec: The `StructType` describing `struct_typed_value`.

    Returns:
      A `(value_proto, type_spec)` tuple. `value_proto` is the same object that
      was passed in.

    Raises:
      TypeError: If the value and `type_spec` have different element counts.
    """
    value_structure = structure.from_container(struct_typed_value)
    if len(value_structure) != len(type_spec):
        raise TypeError(
            'Cannot serialize a struct value of '
            f'{len(value_structure)} elements to a struct type '
            f'requiring {len(type_spec)} elements. Trying to serialize'
            f'\n{struct_typed_value!r}\nto\n{type_spec}.')
    type_elem_iter = structure.iter_elements(type_spec)
    val_elem_iter = structure.iter_elements(value_structure)
    struct_proto = value_proto.struct
    for (e_name, e_type), (_, e_val) in zip(type_elem_iter, val_elem_iter):
        element_proto = struct_proto.element.add()
        serialize_value(e_val, e_type, value_proto=element_proto.value)
        # Unnamed (None or '') type elements leave the proto name unset.
        if e_name:
            element_proto.name = e_name
    return value_proto, type_spec
Ejemplo n.º 8
0
 async def _encrypt_values_on_singleton(self, val, sender, receiver):
     """Encrypts `val` on the single `sender` child once per receiver key.

     Args:
       val: Federated value placed at `sender`; its internal representation
         must contain exactly one entry.
       sender: Placement of the encrypting party; must have exactly one child
         executor.
       receiver: Placement whose public keys are used for encryption.

     Returns:
       A `FederatedResolvingStrategyValue` wrapping a struct with one
       encrypted value per entry of the receiver public-key representation,
       each typed `FederatedType(<encryptor result>, sender,
       all_equal=False)`.
     """
     ###
     # we can safely assume sender has cardinality=1 when receiver is CLIENTS
     ###
     # Case 1: receiver=CLIENTS
     #     plaintext: Fed(Tensor, sender, all_equal=True)
     #     pk_receiver: Fed(Tuple(Tensor), sender, all_equal=True)
     #     sk_sender: Fed(Tensor, sender, all_equal=True)
     #   Returns:
     #     encrypted_values: Tuple(Fed(Tensor, sender, all_equal=True))
     # NOTE(review): the code below builds the result with all_equal=False;
     # confirm which of the comment and the code is intended.
     ###
     ### Check proper key placement
     sk_sender = self.key_references.get_secret_key(sender)
     pk_receiver = self.key_references.get_public_key(receiver)
     type_analysis.check_federated_type(sk_sender.type_signature,
                                        placement=sender)
     # NOTE(review): `assert` checks are removed under `python -O`.
     assert sk_sender.type_signature.placement is sender
     assert pk_receiver.type_signature.placement is sender
     ### Check placement cardinalities
     rcv_children = self.strategy._get_child_executors(receiver)
     snd_children = self.strategy._get_child_executors(sender)
     py_typecheck.check_len(snd_children, 1)
     snd_child = snd_children[0]
     ### Check value cardinalities
     type_analysis.check_federated_type(val.type_signature,
                                        placement=sender)
     py_typecheck.check_len(val.internal_representation, 1)
     py_typecheck.check_type(pk_receiver.type_signature.member,
                             tff.StructType)
     # One public key is expected per receiver child executor.
     py_typecheck.check_len(pk_receiver.internal_representation,
                            len(rcv_children))
     py_typecheck.check_len(sk_sender.internal_representation, 1)
     ### Materialize encryptor function definition & type spec
     input_type = val.type_signature.member
     # Remembers the plaintext member type on the instance (presumably read
     # later when decrypting — verify against the decryption path).
     self._input_type_cache = input_type
     pk_rcv_type = pk_receiver.type_signature.member
     sk_snd_type = sk_sender.type_signature.member
     # All receiver keys are assumed to share the type of the first one.
     pk_element_type = pk_rcv_type[0]
     encryptor_arg_spec = (input_type, pk_element_type, sk_snd_type)
     encryptor_proto, encryptor_type = utils.materialize_computation_from_cache(
         sodium_comp.make_encryptor, self._encryptor_cache,
         encryptor_arg_spec)
     ### Prepare encryption arguments
     v = val.internal_representation[0]
     sk = sk_sender.internal_representation[0]
     ### Encrypt values and return them
     encryptor_fn = await snd_child.create_value(encryptor_proto,
                                                 encryptor_type)
     # One (plaintext, public_key, secret_key) argument struct per receiver key.
     encryptor_args = await asyncio.gather(*[
         snd_child.create_struct([v, this_pk, sk])
         for this_pk in pk_receiver.internal_representation
     ])
     encrypted_values = await asyncio.gather(*[
         snd_child.create_call(encryptor_fn, arg) for arg in encryptor_args
     ])
     encrypted_value_types = [encryptor_type.result] * len(encrypted_values)
     return federated_resolving_strategy.FederatedResolvingStrategyValue(
         structure.from_container(encrypted_values),
         tff.StructType([
             tff.FederatedType(evt, sender, all_equal=False)
             for evt in encrypted_value_types
         ]))
Ejemplo n.º 9
0
def serialize_jax_computation(traced_fn, arg_fn, parameter_type,
                              context_stack):
    """Serializes a Python function containing JAX code as a TFF computation.

    Args:
      traced_fn: The Python function containing JAX code to be traced by JAX
        and serialized as a TFF computation containing XLA code.
      arg_fn: An unpacking function that takes a TFF argument, and returns a
        combo of (args, kwargs) to invoke `traced_fn` with (e.g., as the one
        constructed by `function_utils.create_argument_unpacking_fn`).
      parameter_type: An instance of `computation_types.Type` that represents
        the TFF type of the computation parameter, or `None` if the function
        does not take any parameters.
      context_stack: The context stack to use during serialization.

    Returns:
      An instance of `pb.Computation` with the constructed computation.

    Raises:
      TypeError: if the arguments are of the wrong types.
    """
    py_typecheck.check_callable(traced_fn)
    py_typecheck.check_callable(arg_fn)
    py_typecheck.check_type(context_stack, context_stack_base.ContextStack)

    # Build the fake parameter(s) used for tracing from the TFF parameter type.
    if parameter_type is not None:
        parameter_type = computation_types.to_type(parameter_type)
        packed_arg = _tff_type_to_xla_serializer_arg(parameter_type)
    else:
        packed_arg = None

    args, kwargs = arg_fn(packed_arg)

    # While the fake parameters are fed via args/kwargs during serialization,
    # it is possible for them to get reordered in the actual generated XLA code.
    # We use here the same flattening function as that one, which is used by
    # the JAX serializer to determine the ordering and allow it to be captured
    # in the parameter binding. We do not need to do anything special for the
    # results, since the results, if multiple, are always returned as a tuple.
    flattened_obj, _ = jax.tree_util.tree_flatten((args, kwargs))
    tensor_indexes = list(np.argsort([x.tensor_index for x in flattened_obj]))

    # Tracing runs with a `JaxComputationContext` installed on `context_stack`.
    context = jax_computation_context.JaxComputationContext()
    with context_stack.install(context):
        tracer_callable = jax.xla_computation(traced_fn,
                                              tuple_args=True,
                                              return_shape=True)
        compiled_xla, returned_shape = tracer_callable(*args, **kwargs)

    # A single array result becomes a tensor type; a container of shapes is
    # converted element-wise (recursively) into a TFF struct type.
    if isinstance(returned_shape, jax.ShapeDtypeStruct):
        returned_type_spec = _jax_shape_dtype_struct_to_tff_tensor(
            returned_shape)
    else:
        returned_type_spec = computation_types.to_type(
            structure.map_structure(
                _jax_shape_dtype_struct_to_tff_tensor,
                structure.from_container(returned_shape, recursive=True)))

    computation_type = computation_types.FunctionType(parameter_type,
                                                      returned_type_spec)
    return xla_serialization.create_xla_tff_computation(
        compiled_xla, tensor_indexes, computation_type)
 async def create_struct(self, elements):
   """Wraps `elements`, converted to a `structure.Struct`, in a resolving value."""
   struct_value = structure.from_container(elements)
   return ReferenceResolvingExecutorValue(struct_value)
def create_constant(
        value, type_spec: computation_types.Type) -> ComputationProtoAndType:
    """Returns a tensorflow computation returning a constant `value`.

    The returned computation has the type signature `( -> T)`, where `T` is
    `type_spec`.

    `value` must be a value convertible to a tensor or a structure of values,
    such that the dtype and shapes match `type_spec`. `type_spec` must contain
    only named tuples and tensor types, but these can be arbitrarily nested.
    A non-structured `value` is broadcast to every tensor in `type_spec`.

    Args:
      value: A value to embed as a constant in the tensorflow graph.
      type_spec: A `computation_types.Type` describing the type of the
        constant; must contain only named tuples and tensor types.

    Returns:
      The result of `_tensorflow_comp` for the serialized constant graph, with
      type signature `( -> type_spec)`.

    Raises:
      TypeError: If the constraints of `type_spec` are violated.
    """
    if not type_analysis.is_generic_op_compatible_type(type_spec):
        raise TypeError(
            'Type spec {} cannot be constructed as a TensorFlow constant in TFF; '
            'only nested tuples and tensors are permitted.'.format(type_spec))
    inferred_value_type = type_conversions.infer_type(value)
    if (inferred_value_type.is_struct()
            and not type_spec.is_assignable_from(inferred_value_type)):
        raise TypeError(
            'Must pass only a tensor or structure of tensor values to '
            '`create_constant`; encountered a value {v} with inferred '
            'type {t!r}, but needed {s!r}'.format(v=value,
                                                  t=inferred_value_type,
                                                  s=type_spec))
    if inferred_value_type.is_struct():
        value = structure.from_container(value, recursive=True)
    tensor_dtypes_in_type_spec = []

    def _pack_dtypes(type_signature):
        """Records the dtype of each tensor type; leaves the type unchanged."""
        if type_signature.is_tensor():
            tensor_dtypes_in_type_spec.append(type_signature.dtype)
        return type_signature, False

    type_transformations.transform_type_postorder(type_spec, _pack_dtypes)

    # A non-integer scalar cannot be broadcast into integer tensors.
    if (any(x.is_integer for x in tensor_dtypes_in_type_spec)
            and (inferred_value_type.is_tensor()
                 and not inferred_value_type.dtype.is_integer)):
        raise TypeError(
            'Only integers can be used as scalar values if our desired constant '
            'type spec contains any integer tensors; passed scalar {} of dtype {} '
            'for type spec {}.'.format(value, inferred_value_type.dtype,
                                       type_spec))

    def _create_result_tensor(spec, val):
        """Packs `val` into the (possibly nested) type `spec` recursively."""
        if spec.is_tensor():
            spec.shape.assert_is_fully_defined()
            return tf.constant(val, dtype=spec.dtype, shape=spec.shape)
        elements = []
        if inferred_value_type.is_struct():
            # Copy the leaf values according to the `spec` structure.
            for (name, elem_type), elem_val in zip(
                    structure.iter_elements(spec), val):
                elements.append(
                    (name, _create_result_tensor(elem_type, elem_val)))
        else:
            # "Broadcast" the value to each level of the `spec` structure.
            for _, elem_type in structure.iter_elements(spec):
                elements.append(
                    (None, _create_result_tensor(elem_type, val)))
        return structure.Struct(elements)

    with tf.Graph().as_default() as graph:
        result = _create_result_tensor(type_spec, value)
        _, result_binding = tensorflow_utils.capture_result_from_graph(
            result, graph)

    type_signature = computation_types.FunctionType(None, type_spec)
    tensorflow = pb.TensorFlow(graph_def=serialization_utils.pack_graph_def(
        graph.as_graph_def()),
                               parameter=None,
                               result=result_binding)
    return _tensorflow_comp(tensorflow, type_signature)
Ejemplo n.º 12
0
 def test_from_container_with_namedtuple(self):
   """A namedtuple converts into a named Struct keeping field order."""
   pair_cls = collections.namedtuple('_', 'x y')
   result = structure.from_container(pair_cls(1, 2))
   self.assertIsInstance(result, structure.Struct)
   self.assertEqual(str(result), '<x=1,y=2>')
Ejemplo n.º 13
0
 def test_from_container_with_dict(self):
   """A plain dict converts into a Struct whose names are sorted by key."""
   unordered = {'z': 10, 'y': 20, 'a': 30}
   result = structure.from_container(unordered)
   self.assertIsInstance(result, structure.Struct)
   self.assertEqual(str(result), '<a=30,y=20,z=10>')
Ejemplo n.º 14
0
def to_representation_for_type(
        value: Any,
        tf_function_cache: MutableMapping[str, Any],
        type_spec: Optional[computation_types.Type] = None,
        device: Optional[tf.config.LogicalDevice] = None) -> Any:
    """Verifies or converts the `value` to an eager object matching `type_spec`.

    NOTE(review): an earlier version of this docstring warned that data sets
    are not supported, yet the `is_sequence()` branch below accepts lists and
    `TF_DATASET_REPRESENTATION_TYPES` values — confirm that warning is stale.

    The output of this function is always an eager tensor, eager dataset, a
    representation of a TensorFlow computation, or a nested structure of those
    that matches `type_spec`, and when `device` has been specified, everything
    is placed on that device on a best-effort basis.

    TensorFlow computations are represented here as zero- or one-argument
    Python callables that accept their entire argument bundle as a single
    Python object.

    Args:
      value: The raw representation of a value to compare against `type_spec`
        and potentially to be converted.
      tf_function_cache: A cache obeying `dict` semantics that can be used to
        look up previously embedded TensorFlow functions.
      type_spec: An instance of `tff.Type`, can be `None` for values that
        derive from `typed_object.TypedObject`.
      device: An optional `tf.config.LogicalDevice` to place the value on (for
        tensor-level values).

    Returns:
      Either `value` itself, or a modified version of it.

    Raises:
      TypeError: If the `value` is not compatible with `type_spec`.
    """
    type_spec = type_utils.reconcile_value_with_type_spec(value, type_spec)
    if isinstance(value, computation_base.Computation):
        return to_representation_for_type(
            computation_impl.ComputationImpl.get_proto(value),
            tf_function_cache, type_spec, device)
    elif isinstance(value, pb.Computation):
        # Embedded functions are cached per (proto bytes, type, device name).
        key = (value.SerializeToString(), str(type_spec),
               device.name if device else None)
        cached_fn = tf_function_cache.get(key)
        if cached_fn is not None:
            return cached_fn
        embedded_fn = embed_tensorflow_computation(value, type_spec, device)
        tf_function_cache[key] = embedded_fn
        return embedded_fn
    elif type_spec.is_struct():
        # Checked before the `device` branch, so members recurse with the same
        # `device` and are each placed individually.
        type_elem = structure.to_elements(type_spec)
        value_elem = (structure.to_elements(structure.from_container(value)))
        result_elem = []
        if len(type_elem) != len(value_elem):
            raise TypeError(
                'Expected a {}-element tuple, found {} elements.'.format(
                    len(type_elem), len(value_elem)))
        for (t_name, el_type), (v_name, el_val) in zip(type_elem, value_elem):
            if t_name != v_name:
                raise TypeError(
                    'Mismatching element names in type vs. value: {} vs. {}.'.
                    format(t_name, v_name))
            el_repr = to_representation_for_type(el_val, tf_function_cache,
                                                 el_type, device)
            result_elem.append((t_name, el_repr))
        return structure.Struct(result_elem)
    elif device is not None:
        # Re-enter once under `tf.device`, with `device=None` so this branch
        # is not taken again.
        py_typecheck.check_type(device, tf.config.LogicalDevice)
        with tf.device(device.name):
            return to_representation_for_type(value,
                                              tf_function_cache,
                                              type_spec=type_spec,
                                              device=None)
    elif isinstance(value, EagerValue):
        return value.internal_representation
    elif isinstance(value, executor_value_base.ExecutorValue):
        raise TypeError(
            'Cannot accept a value embedded within a non-eager executor.')
    elif type_spec.is_tensor():
        if not tf.is_tensor(value):
            value = tf.convert_to_tensor(value, dtype=type_spec.dtype)
        elif hasattr(value, 'read_value'):
            # a tf.Variable-like result, get a proper tensor.
            value = value.read_value()
        value_type = (computation_types.TensorType(value.dtype.base_dtype,
                                                   value.shape))
        if not type_spec.is_assignable_from(value_type):
            raise TypeError(
                'The apparent type {} of a tensor {} does not match the expected '
                'type {}.'.format(value_type, value, type_spec))
        return value
    elif type_spec.is_sequence():
        # Python lists are first turned into a data set of the element type.
        if isinstance(value, list):
            value = tensorflow_utils.make_data_set_from_elements(
                None, value, type_spec.element)
        py_typecheck.check_type(
            value, type_conversions.TF_DATASET_REPRESENTATION_TYPES)
        element_type = computation_types.to_type(value.element_spec)
        value_type = computation_types.SequenceType(element_type)
        type_spec.check_assignable_from(value_type)
        return value
    else:
        raise TypeError('Unexpected type {}.'.format(type_spec))
Ejemplo n.º 15
0
 async def _to_structure(coro: Coroutine[Any, Any, Any]) -> structure.Struct:
   """Awaits `coro` and converts its result with `structure.from_container`."""
   awaited = await coro
   return structure.from_container(awaited)
Ejemplo n.º 16
0
 def test_from_container_with_int(self):
   """A bare int is not a container and is rejected with TypeError."""
   self.assertRaises(TypeError, structure.from_container, 10)
Ejemplo n.º 17
0
 def test_from_container_with_tuple(self):
   """A plain tuple converts into an unnamed Struct."""
   result = structure.from_container((10, 20))
   self.assertIsInstance(result, structure.Struct)
   self.assertEqual(str(result), '<10,20>')
Ejemplo n.º 18
0
 def test_from_container_sparse_tensor(self):
   """A `tf.SparseTensor` converts into a struct of its three components."""
   sparse = tf.SparseTensor(indices=[[1]], values=[2], dense_shape=[5])
   self.assertEqual(
       str(structure.from_container(sparse)),
       '<indices=[[1]],values=[2],dense_shape=[5]>')
Ejemplo n.º 19
0
 def test_from_container_with_ordered_dict(self):
   """An OrderedDict keeps its insertion order (not sorted key order)."""
   ordered = collections.OrderedDict([('z', 10), ('y', 20), ('a', 30)])
   result = structure.from_container(ordered)
   self.assertIsInstance(result, structure.Struct)
   self.assertEqual(str(result), '<z=10,y=20,a=30>')
Ejemplo n.º 20
0
 def test_from_container_asdict_roundtrip(self, dict_in):
   """`dict_in` survives a recursive Struct conversion and `_asdict` back."""
   as_struct = structure.from_container(dict_in, recursive=True)
   self.assertEqual(dict_in, as_struct._asdict(recursive=True))
Ejemplo n.º 21
0
 def test_from_container_with_struct(self):
   """A `Struct` passed to `from_container` is returned as the same object."""
   original = structure.Struct([('a', 10), ('b', 20)])
   x = structure.from_container(original)
   # The previous `assertIs(x, x)` was tautological; compare to the input.
   self.assertIs(x, original)
Ejemplo n.º 22
0
 def test_from_container_raises_on_non_container_argument(self):
   """Non-container input (an int) raises TypeError."""
   self.assertRaises(TypeError, structure.from_container, 3)
Ejemplo n.º 23
0
 def test_to_odict_or_tuple_from_container_roundtrip(self, original):
   """`original` survives from_container(recursive) + to_odict_or_tuple."""
   roundtripped = structure.to_odict_or_tuple(
       structure.from_container(original, recursive=True))
   self.assertEqual(original, roundtripped)
Ejemplo n.º 24
0
 def test_to_odict_or_tuple_empty_dict_becomes_empty_tuple(self):
   """An empty OrderedDict converts back to `()`, not to a dict."""
   empty_struct = structure.from_container(collections.OrderedDict())
   self.assertEqual(structure.to_odict_or_tuple(empty_struct), ())
def _ensure_structure(obj):
  """Returns `obj` recursively converted into a `structure.Struct`."""
  converted = structure.from_container(obj, recursive=True)
  return converted
Ejemplo n.º 26
0
 def test_to_odict_or_tuple_mixed_nonrecursive(self):
   """Non-recursive conversion round-trips a mixed nested container."""
   original = ODict(a=1, b=2, c=(3, ODict(d=4, e=5)))
   shallow = structure.from_container(original, recursive=False)
   restored = structure.to_odict_or_tuple(shallow, recursive=False)
   self.assertEqual(original, restored)
Ejemplo n.º 27
0
 def result_computation(arg):
     """Evaluates `to_evaluate` with `arg` bound to `parameter_name`.

     Struct-typed parameters are first converted recursively with
     `structure.from_container`.
     """
     bound_arg = (structure.from_container(arg, recursive=True)
                  if parameter_type.is_struct() else arg)
     return _evaluate_to_tensorflow(to_evaluate, {parameter_name: bound_arg})
Ejemplo n.º 28
0
    async def create_value(self, value, type_spec=None):
        """Creates a value in this executor.

        The following kinds of `value` are supported as the input:

        * An instance of TFF computation proto containing one of the supported
          sequence intrinsics as its sole body.

        * Any value whose `type_spec` is a `SequenceType` (e.g. an eager TF
          dataset); it is wrapped via `_SequenceFromPayload`.

        * Anything that is supported by the target executor (as a pass-through).

        * A nested structure of any of the above.

        Args:
          value: The input for which to create a value.
          type_spec: An optional TFF type (required if `value` is not an
            instance of `typed_object.TypedObject`, otherwise it can be
            `None`).

        Returns:
          An instance of `SequenceExecutorValue` that represents the embedded
          value; for struct types, the result of `self.create_struct` over the
          recursively embedded leaves.

        Raises:
          ValueError: If `value` is an intrinsic proto with an unrecognized URI.
        """
        if type_spec is None:
            py_typecheck.check_type(value, typed_object.TypedObject)
            type_spec = value.type_signature
        else:
            type_spec = computation_types.to_type(type_spec)
        # Sequence-typed values are kept here rather than sent to the target.
        if isinstance(type_spec, computation_types.SequenceType):
            return SequenceExecutorValue(
                _SequenceFromPayload(value, type_spec), type_spec)
        if isinstance(value, pb.Computation):
            value_type = type_serialization.deserialize_type(value.type)
            value_type.check_equivalent_to(type_spec)
            which_computation = value.WhichOneof('computation')
            # NOTE: If not a supported type of intrinsic, we let it fall through and
            # be handled by embedding in the target executor (below).
            if which_computation == 'intrinsic':
                intrinsic_def = intrinsic_defs.uri_to_intrinsic_def(
                    value.intrinsic.uri)
                if intrinsic_def is None:
                    raise ValueError(
                        'Encountered an unrecognized intrinsic "{}".'.format(
                            value.intrinsic.uri))
                op_type = SequenceExecutor._SUPPORTED_INTRINSIC_TO_SEQUENCE_OP.get(
                    intrinsic_def.uri)
                if op_type is not None:
                    type_analysis.check_concrete_instance_of(
                        type_spec, intrinsic_def.type_signature)
                    op = op_type(type_spec)
                    return SequenceExecutorValue(op, type_spec)
        if isinstance(type_spec, computation_types.StructType):
            # Embed every leaf on its own (so nested sequences are handled by
            # this executor), then rebuild the nesting and create a struct.
            # NOTE(review): `zip` truncates silently if value and type flatten
            # to different lengths; `pack_sequence_as` may then fail.
            if not isinstance(value, structure.Struct):
                value = structure.from_container(value)
            elements = structure.flatten(value)
            element_types = structure.flatten(type_spec)
            flat_embedded_vals = await asyncio.gather(*[
                self.create_value(el, el_type)
                for el, el_type in zip(elements, element_types)
            ])
            embedded_struct = structure.pack_sequence_as(
                value, flat_embedded_vals)
            return await self.create_struct(embedded_struct)
        target_value = await self._target_executor.create_value(
            value, type_spec)
        return SequenceExecutorValue(target_value, type_spec)
Ejemplo n.º 29
0
class FederatingExecutorCreateCallTest(executor_test_utils.AsyncTestCase,
                                       parameterized.TestCase,
                                       test_case.TestCase):
    """Tests calling the `federated_secure_sum` intrinsic via `create_call`."""

    # Each case is (name, client_values, bitwidth, expected_result). Per the
    # inline comments, the expected result is the client sum modulo
    # 2**(bitwidth + 2), applied element-wise for structured values.
    @parameterized.named_parameters(
        (
            'scalar_sum_less_than_mask',
            [10, 11, 12],
            10,  # larger than sum and individual client values.
            33,  # 33 mod (2**(10+2))
        ),
        (
            'scalar_sum_less_than_extended_mask',
            [10, 11, 12],
            4,  # larger than individual client values, but smaller than sum.
            33,  # 33 mod (2**(4+2))
        ),
        (
            'scalar_sum_greater_than_extend_mask',
            [10, 11, 12],
            2,  # smaller than individual client values.
            1,  # 33 mod (2**(2+2))
        ),
        (
            'structured_scalar_sum_less_than_mask',
            [[10, 0], [11, 1], [12, 2]],
            [10, 10],
            structure.from_container([33, 3]),
        ),
        (
            'structued_scalar_sum_greater_than_mask',
            [[10, 0], [11, 1], [12, 2]],
            [3, 3],
            structure.from_container([1, 3]),
        ),
        (
            'named_structured_scalar_sum_less_than_mask',
            [
                collections.OrderedDict(a=10, b=0),
                collections.OrderedDict(a=11, b=1),
                collections.OrderedDict(a=12, b=2),
            ],
            collections.OrderedDict(a=10, b=10),
            structure.from_container(collections.OrderedDict(a=33, b=3)),
        ),
        (
            'named_structued_scalar_sum_greater_than_mask',
            [
                collections.OrderedDict(a=10, b=0),
                collections.OrderedDict(a=11, b=1),
                collections.OrderedDict(a=12, b=2),
            ],
            collections.OrderedDict(a=3, b=3),
            structure.from_container(collections.OrderedDict(a=1, b=3)),
        ),
        (
            'nested_structured_tensor_sum',
            [
                collections.OrderedDict(
                    a=100, b=[tf.constant([0, 1, 2]),
                              tf.constant([10, 70])]),
                collections.OrderedDict(
                    a=1000, b=[tf.constant([1, 2, 3]),
                               tf.constant([20, 80])]),
                collections.OrderedDict(
                    a=10000, b=[tf.constant([2, 3, 4]),
                                tf.constant([30, 90])]),
            ],
            collections.OrderedDict(a=16, b=[4, 5]),
            structure.from_container(
                collections.OrderedDict(a=11100,
                                        b=structure.from_container([
                                            tf.constant([3, 6, 9]),
                                            tf.constant([60, 240 &
                                                         (2**7 - 1)]),
                                        ]))),
        ),
    )
    def test_returns_value_with_intrinsic_def_federated_secure_sum(
            self, client_values, bitwidth, expected_result):
        """Embeds the intrinsic and its args, calls it, and checks the result.

        Struct expectations are compared leaf-by-leaf with `assertAllEqual`;
        scalar expectations with `assertEqual`.
        """
        executor = create_test_executor()
        value_type = computation_types.at_clients(
            type_conversions.infer_type(client_values[0]))
        bitwidth_type = type_conversions.infer_type(bitwidth)
        comp, comp_type = create_intrinsic_def_federated_secure_sum(
            value_type.member, bitwidth_type)

        comp = self.run_sync(executor.create_value(comp, comp_type))
        arg_1 = self.run_sync(executor.create_value(client_values, value_type))
        arg_2 = self.run_sync(executor.create_value(bitwidth, bitwidth_type))
        args = self.run_sync(executor.create_struct([arg_1, arg_2]))
        result = self.run_sync(executor.create_call(comp, args))

        self.assertIsInstance(result, executor_value_base.ExecutorValue)
        self.assert_types_identical(result.type_signature, comp_type.result)
        actual_result = self.run_sync(result.compute())
        if isinstance(expected_result, structure.Struct):
            structure.map_structure(self.assertAllEqual, actual_result,
                                    expected_result)
        else:
            self.assertEqual(actual_result, expected_result)
Ejemplo n.º 30
0
 def test_from_container_with_namedtuple_of_odict_recursive(self):
     """Recursive conversion renders nested OrderedDicts as nested structs."""
     pair_cls = collections.namedtuple('_', 'x y')
     nested = pair_cls(
         collections.OrderedDict([('a', 10), ('b', 20)]),
         collections.OrderedDict([('c', 30), ('d', 40)]))
     result = structure.from_container(nested, recursive=True)
     self.assertEqual(str(result), '<x=<a=10,b=20>,y=<c=30,d=40>>')