def assert_less_equal_max_and_add(summation_and_max_input, summand):
    summation, original_max_input = summation_and_max_input
    if max_input_type.is_struct():
        max_input = original_max_input
    else:
        # Broadcast max_input to the same structure as the summand.
        max_input = structure.map_structure(
            lambda *args: original_max_input, summand)

    # Assert that every coordinate in every tensor is less than or equal to
    # the maximum input value allowed for the secure sum.
    def assert_all_coordinates_less_equal(x, m):
        return tf.Assert(
            tf.reduce_all(
                tf.less_equal(tf.cast(x, tf.int64),
                              tf.cast(m, tf.int64))),
            [
                'client value larger than maximum specified for secure sum',
                x, 'not less than or equal to', m
            ])

    assert_ops = structure.flatten(
        structure.map_structure(assert_all_coordinates_less_equal,
                                summand, max_input))
    with tf.control_dependencies(assert_ops):
        return structure.map_structure(tf.add, summation,
                                       summand), original_max_input
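# A minimal standalone sketch of the guard pattern above, using plain
# TensorFlow: the addition only executes once the bound check passes. The
# name `checked_add` is illustrative, not part of the snippet.
import tensorflow as tf

@tf.function
def checked_add(total, summand, max_input):
    assert_op = tf.Assert(
        tf.reduce_all(
            tf.less_equal(tf.cast(summand, tf.int64),
                          tf.cast(max_input, tf.int64))),
        ['client value larger than maximum specified for secure sum'])
    with tf.control_dependencies([assert_op]):
        return tf.add(total, summand)

# checked_add(tf.constant(3), tf.constant(4), tf.constant(10)) -> 7, while a
# summand of 40 raises tf.errors.InvalidArgumentError.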
 def zeros_fn():
   if member_type.is_struct():
     structure.map_structure(lambda v: _validate_dtype_is_numeric(v.dtype),
                             member_type)
     return structure.map_structure(
         lambda v: tf.fill(v.shape, value=initial_value_fn(v)), member_type)
   _validate_dtype_is_numeric(member_type.dtype)
   return tf.fill(member_type.shape, value=initial_value_fn(member_type))
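# A quick illustration of the `tf.fill` call that `zeros_fn` relies on: it
# builds a tensor of the given shape with every element set to one scalar.
import tensorflow as tf

print(tf.fill([2, 3], value=7))
# tf.Tensor([[7 7 7]
#            [7 7 7]], shape=(2, 3), dtype=int32)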
Example #3
 def test_map_structure_tensors(self):
   x = tf.constant(1)
   y = tf.constant(2)
   self.assertAllEqual(structure.map_structure(tf.add, x, y), 3)
   x = tf.strings.bytes_split('abc')
   y = tf.strings.bytes_split('xyz')
   self.assertAllEqual(
       structure.map_structure(tf.add, x, y), ['ax', 'by', 'cz'])
 async def create_struct(self, elements):
   target_val = await self._target.create_struct(
       structure.map_structure(lambda x: x.value, elements))
   wrapped_val = TracingExecutorValue(self, self._get_new_value_index(),
                                      target_val)
   self._trace.append(
       ('create_struct', structure.map_structure(lambda x: x.index,
                                                 elements), wrapped_val.index))
   return wrapped_val
Example #5
 def test_map_structure_fails_different_structures(self):
   x = structure.Struct.named(a=10, c=20)
   y = structure.Struct.named(a=30)
   with self.assertRaises(TypeError):
     structure.map_structure(tf.add, x, y)
   x = structure.Struct.named(a=10)
   y = structure.Struct.named(a=30, c=tf.strings.bytes_split('abc'))
   with self.assertRaises(TypeError):
     structure.map_structure(tf.add, x, y)
Example #6
 def test_map_structure_tensor_fails(self):
   x = structure.Struct.named(a=10, c=20)
   y = tf.constant(2)
   with self.assertRaises(TypeError):
     structure.map_structure(tf.add, x, y)
   x = structure.Struct.named(a='abc', c='xyz')
   y = tf.strings.bytes_split('abc')
   with self.assertRaises(TypeError):
     structure.map_structure(tf.add, x, y)
class CreateUnaryOperatorTest(parameterized.TestCase, tf.test.TestCase):
    @parameterized.named_parameters(
        ('abs_int', tf.math.abs, _TensorType(tf.int32), [-1], 1),
        ('abs_float', tf.math.abs, _TensorType(tf.float32), [-1.0], 1.0),
        ('abs_unnamed_tuple',
         lambda x: structure.map_structure(tf.math.abs, x),
         _StructType(
             [_TensorType(tf.int32, [2]),
              _TensorType(tf.float32, [2])]), [[-1, -2], [-3.0, -4.0]],
         structure.Struct([(None, [1, 2]), (None, [3.0, 4.0])])),
        ('abs_named_tuple', lambda x: structure.map_structure(tf.math.abs, x),
         _StructType([('a', _TensorType(tf.int32, [2])),
                      ('b', _TensorType(tf.float32, [2]))]), [
                          [-1, -2], [-3.0, -4.0]
                      ], structure.Struct([('a', [1, 2]), ('b', [3.0, 4.0])])),
        ('reduce_sum_int', tf.math.reduce_sum, _TensorType(tf.int32,
                                                           [2]), [2, 2], 4),
        ('reduce_sum_float', tf.math.reduce_sum, _TensorType(
            tf.float32, [2]), [2.0, 2.5], 4.5),
        ('log_inf', tf.math.log, _TensorType(tf.float32), [0.0], -np.inf),
    )
    # pyformat: enable
    def test_returns_computation(self, operator, operand_type, operand,
                                 expected_result):
        proto, _ = tensorflow_computation_factory.create_unary_operator(
            operator, operand_type)

        self.assertIsInstance(proto, pb.Computation)
        actual_type = type_serialization.deserialize_type(proto.type)
        self.assertIsInstance(actual_type, computation_types.FunctionType)
        # Note: It is only useful to test the parameter type; the result type
        # depends on the `operator` used, not on the implementation of
        # `create_unary_operator`.
        expected_parameter_type = operand_type
        self.assertEqual(actual_type.parameter, expected_parameter_type)
        actual_result = test_utils.run_tensorflow(proto, operand)
        self.assertAllEqual(actual_result, expected_result)

    @parameterized.named_parameters(
        ('non_callable_operator', 1, _TensorType(tf.int32)),
        ('none_type', tf.math.add, None),
        ('federated_type', tf.math.add, computation_types.at_server(tf.int32)),
        ('sequence_type', tf.math.add, computation_types.SequenceType(
            tf.int32)),
    )
    def test_raises_type_error(self, operator, type_signature):

        with self.assertRaises(TypeError):
            tensorflow_computation_factory.create_unary_operator(
                operator, type_signature)
Example #8
 def test_input_spec_struct(self):
   keras_model = model_examples.build_linear_regression_keras_functional_model(
       feature_dims=1)
   input_spec = computation_types.StructType(
       collections.OrderedDict(
           x=tf.TensorSpec(shape=[None, 1], dtype=tf.float32),
           y=tf.TensorSpec(shape=[None, 1], dtype=tf.float32)))
   tff_model = keras_utils.from_keras_model(
       keras_model=keras_model,
       input_spec=input_spec,
       loss=tf.keras.losses.MeanSquaredError())
   self.assertIsInstance(tff_model, model_utils.EnhancedModel)
   self.assertIsInstance(tff_model.input_spec, structure.Struct)
   structure.map_structure(lambda x: self.assertIsInstance(x, tf.TensorSpec),
                           tff_model.input_spec)
Example #9
  def test_map_structure(self):
    x = structure.Struct.named(
        a=10,
        b=structure.Struct.named(
            x=structure.Struct.named(p=40),
            y=30,
            z=structure.Struct.named(q=50, r=60)),
        c=20)
    y = structure.Struct.named(
        a=1,
        b=structure.Struct.named(
            x=structure.Struct.named(p=4),
            y=3,
            z=structure.Struct.named(q=5, r=6)),
        c=2)

    add = lambda v1, v2: v1 + v2
    self.assertEqual(
        structure.map_structure(add, x, y),
        structure.Struct.named(
            a=11,
            b=structure.Struct.named(
                x=structure.Struct.named(p=44),
                y=33,
                z=structure.Struct.named(q=55, r=66)),
            c=22))
def identity(source):
    """Applies `tf.identity` pointwise to `source`.

  This utility function provides the exact same behavior as `tf.identity`, but
  it generalizes to a wider class of objects, including ordinary tensors and
  variables, as well as various types of nested structures. It would typically
  be used together with `tf.control_dependencies` in non-eager TensorFlow.

  Args:
    source: A nested structure composed of tensors or variables embedded in
      containers that are compatible with `tf.nest`, or instances of
      `structure.Struct`. Elements that represent variables have
      their content extracted prior to identity mapping by first invoking
      `tf.Variable.read_value`.

  Returns:
    The result of applying `tf.identity` pointwise to every element of
    `source`, with the same structure as `source`.

  Raises:
    TypeError: If types mismatch.
  """
    def _mapping_fn(x):
        if not tf.is_tensor(x):
            raise TypeError('Expected a tensor, found {}.'.format(
                py_typecheck.type_string(type(x))))
        if hasattr(x, 'read_value'):
            x = x.read_value()
        return tf.identity(x)

    # TODO(b/113112108): Extend this to containers of mixed types.
    if isinstance(source, structure.Struct):
        return structure.map_structure(_mapping_fn, source)
    else:
        return tf.nest.map_structure(_mapping_fn, source)
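# A hypothetical usage sketch for `identity`, assuming the function above is
# in scope; plain `tf.nest`-compatible containers work directly, and
# variables are read before the identity mapping.
import tensorflow as tf

v = tf.Variable([1.0, 2.0])
nested = {'weights': v, 'count': tf.constant(3)}
copies = identity(nested)
print(copies['weights'])  # tf.Tensor([1. 2.], shape=(2,), dtype=float32)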
Example #11
def serialize_jax_computation(traced_fn, arg_fn, parameter_type, context_stack):
  """Serializes a Python function containing JAX code as a TFF computation.

  Args:
    traced_fn: The Python function containing JAX code to be traced by JAX and
      serialized as a TFF computation containing XLA code.
    arg_fn: An unpacking function that takes a TFF argument and returns a pair
      of (args, kwargs) with which to invoke `traced_fn` (e.g., one constructed
      by `function_utils.create_argument_unpacking_fn`).
    parameter_type: An instance of `computation_types.Type` that represents the
      TFF type of the computation parameter, or `None` if the function does not
      take any parameters.
    context_stack: The context stack to use during serialization.

  Returns:
    An instance of `pb.Computation` with the constructed computation.

  Raises:
    TypeError: if the arguments are of the wrong types.
  """
  py_typecheck.check_callable(traced_fn)
  py_typecheck.check_callable(arg_fn)
  py_typecheck.check_type(context_stack, context_stack_base.ContextStack)

  if parameter_type is not None:
    parameter_type = computation_types.to_type(parameter_type)
    packed_arg = _tff_type_to_xla_serializer_arg(parameter_type)
  else:
    packed_arg = None

  args, kwargs = arg_fn(packed_arg)

  # While the fake parameters are fed via args/kwargs during serialization,
  # it is possible for them to get reordered in the actual generated XLA code.
  # We use here the same flattening function as that one, which is used by
  # the JAX serializer to determine the ordering and allow it to be captured
  # in the parameter binding. We do not need to do anything special for the
  # results, since the results, if multiple, are always returned as a tuple.
  flattened_obj, _ = jax.tree_util.tree_flatten((args, kwargs))
  tensor_indexes = list(np.argsort([x.tensor_index for x in flattened_obj]))

  context = jax_computation_context.JaxComputationContext()
  with context_stack.install(context):
    tracer_callable = jax.xla_computation(
        traced_fn, tuple_args=True, return_shape=True)
    compiled_xla, returned_shape = tracer_callable(*args, **kwargs)

  if isinstance(returned_shape, jax.ShapeDtypeStruct):
    returned_type_spec = _jax_shape_dtype_struct_to_tff_tensor(returned_shape)
  else:
    returned_type_spec = computation_types.to_type(
        structure.map_structure(
            _jax_shape_dtype_struct_to_tff_tensor,
            structure.from_container(returned_shape, recursive=True)))

  computation_type = computation_types.FunctionType(parameter_type,
                                                    returned_type_spec)
  return xla_serialization.create_xla_tff_computation(compiled_xla,
                                                      tensor_indexes,
                                                      computation_type)
def _get_accumulator_type(member_type):
  """Constructs a `tff.Type` for the accumulator in sample aggregation.

  Args:
    member_type: A `tff.Type` representing the member components of the
      federated type.

  Returns:
    The `tff.StructType` associated with the accumulator. The tuple contains
    two parts, `accumulators` and `rands`, which are parallel lists: the i-th
    element of one corresponds to the i-th element of the other. These two
    lists are used to sample from the accumulators with equal probability.
  """
  # TODO(b/121288403): Special-casing anonymous tuple shouldn't be needed.
  if member_type.is_struct():
    a = structure.map_structure(
        lambda v: computation_types.TensorType(v.dtype, [None] + v.shape.dims),
        member_type)
    return computation_types.StructType(
        collections.OrderedDict({
            'accumulators':
                computation_types.StructType(structure.to_odict(a, True)),
            'rands':
                computation_types.TensorType(tf.float32, shape=[None]),
        }))
  return computation_types.StructType(
      collections.OrderedDict({
          'accumulators':
              computation_types.TensorType(
                  member_type.dtype, shape=[None] + member_type.shape.dims),
          'rands':
              computation_types.TensorType(tf.float32, shape=[None]),
      }))
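# An illustrative sketch of the accumulator type this produces, assuming
# `_get_accumulator_type` is in scope and TFF's public type constructors are
# available: each leaf gains a leading unknown dimension for the samples.
import tensorflow as tf
import tensorflow_federated as tff

member_type = tff.TensorType(tf.float32, [3])
print(_get_accumulator_type(member_type))
# expected to render along the lines of: <accumulators=float32[?,3],rands=float32[?]>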
Example #13
def _remove_batch_dim(
    type_spec: computation_types.Type) -> computation_types.Type:
  """Removes the batch dimension from the `tff.TensorType`s in `type_spec`.

  Args:
    type_spec: A `tff.Type` containing `tff.TensorType`s as leaves. The first
      dimension in the leaf `tff.TensorType` is the batch dimension.

  Returns:
    A `tff.Type` of the same structure as `type_spec`, with no batch dimensions
    in all the leaf `tff.TensorType`s.

  Raises:
    TypeError: If the argument has the wrong type.
    ValueError: If the `tff.TensorType` does not have the first dimension.
  """

  def _remove_first_dim_in_tensortype(tensor_type):
    """Return a new `tff.TensorType` after removing the first dimension."""
    py_typecheck.check_type(tensor_type, computation_types.TensorType)
    if (tensor_type.shape.rank is not None) and (tensor_type.shape.rank >= 1):
      return computation_types.TensorType(
          shape=tensor_type.shape[1:], dtype=tensor_type.dtype)
    else:
      raise ValueError('Provided shape must have rank 1 or higher.')

  return structure.map_structure(_remove_first_dim_in_tensortype, type_spec)
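# A hedged usage sketch, assuming `_remove_batch_dim` is in scope and TFF is
# installed: a leaf of type float32[None,28,28] loses its leading batch
# dimension.
import tensorflow as tf
import tensorflow_federated as tff

batched = tff.TensorType(tf.float32, [None, 28, 28])
print(_remove_batch_dim(batched))  # expected: float32[28,28]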
def assign(target, source):
    """Creates an op that assigns `target` from `source`.

  This utility function provides the exact same behavior as
  `tf.Variable.assign`, but it generalizes to a wider class of objects,
  including ordinary variables as well as various types of nested structures.

  Args:
    target: A nested structure composed of variables embedded in containers that
      are compatible with `tf.nest`, or instances of
      `structure.Struct`.
    source: A nested structure composed of tensors, matching that of `target`.

  Returns:
    A single op that represents the assignment.

  Raises:
    TypeError: If types mismatch.
  """
    # TODO(b/113112108): Extend this to containers of mixed types.
    if isinstance(target, structure.Struct):
        return tf.group(*structure.flatten(
            structure.map_structure(lambda a, b: a.assign(b), target, source)))
    else:
        return tf.group(*tf.nest.flatten(
            tf.nest.map_structure(lambda a, b: a.assign(b), target, source)))
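# A minimal sketch of `assign` on a nested container, assuming the function
# above is in scope; in eager TensorFlow the grouped assignment runs
# immediately.
import tensorflow as tf

target = {'w': tf.Variable([0.0, 0.0]), 'b': tf.Variable(0.0)}
source = {'w': tf.constant([1.0, 2.0]), 'b': tf.constant(3.0)}
assign(target, source)
print(target['w'].numpy())  # [1. 2.]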
def _compute_summation_type_for_bitwidth(bitwidth, type_spec):
    """Creates a `tff.Type` with dtype based on bitwidth."""
    def type_for_bitwidth_limited_tensor(bits, tensor_type):
        if bits < 1 or bits > MAXIMUM_SUPPORTED_BITWIDTH:
            raise ValueError(
                'Encountered a bitwidth that cannot be handled: {b}. '
                'Bitwidth must be in the range [1, {m}].'
                '\nRequested: {r}'.format(b=bits,
                                          r=bitwidth,
                                          m=MAXIMUM_SUPPORTED_BITWIDTH))
        elif bits < 32:
            return computation_types.TensorType(
                shape=tensor_type.shape,
                dtype=tf.uint32 if tensor_type.dtype.is_unsigned else tf.int32)
        else:
            return computation_types.TensorType(
                shape=tensor_type.shape,
                dtype=tf.uint64 if tensor_type.dtype.is_unsigned else tf.int64)

    if type_spec.is_tensor():
        return type_for_bitwidth_limited_tensor(bitwidth, type_spec)
    elif type_spec.is_struct():
        return computation_types.StructType(
            structure.iter_elements(
                structure.map_structure(type_for_bitwidth_limited_tensor,
                                        bitwidth, type_spec)))
    else:
        raise TypeError(
            'Summation types can only be created from TensorType or '
            'StructType. Received a {t}'.format(t=type_spec))
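# An illustrative sketch, assuming `_compute_summation_type_for_bitwidth` and
# the module constant `MAXIMUM_SUPPORTED_BITWIDTH` are in scope: bitwidths
# below 32 sum in 32-bit types, larger ones in 64-bit types.
import tensorflow as tf
import tensorflow_federated as tff

print(_compute_summation_type_for_bitwidth(16, tff.TensorType(tf.int32)))  # int32
print(_compute_summation_type_for_bitwidth(40, tff.TensorType(tf.int32)))  # int64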
Example #16
  def test_map_structure(self):
    x = structure.Struct([
        ('a', 10),
        ('b',
         structure.Struct([
             ('x', structure.Struct([('p', 40)])),
             ('y', 30),
             ('z', structure.Struct([('q', 50), ('r', 60)])),
         ])),
        ('c', 20),
    ])
    y = structure.Struct([
        ('a', 1),
        ('b',
         structure.Struct([
             ('x', structure.Struct([('p', 4)])),
             ('y', 3),
             ('z', structure.Struct([('q', 5), ('r', 6)])),
         ])),
        ('c', 2),
    ])

    self.assertEqual(
        structure.map_structure(lambda x, y: x + y, x, y),
        structure.Struct([
            ('a', 11),
            ('b',
             structure.Struct([
                 ('x', structure.Struct([('p', 44)])),
                 ('y', 33),
                 ('z', structure.Struct([('q', 55), ('r', 66)])),
             ])),
            ('c', 22),
        ]))
Example #17
 def federated_secure_sum(self, value, bitwidth):
   """Implements `federated_secure_sum` as defined in `api/intrinsics.py`."""
   value = value_impl.to_value(value, None, self._context_stack)
   value = value_utils.ensure_federated_value(value,
                                              placement_literals.CLIENTS,
                                              'value to be summed')
   type_analysis.check_is_structure_of_integers(value.type_signature)
   bitwidth_value = value_impl.to_value(bitwidth, None, self._context_stack)
   value_member_type = value.type_signature.member
   bitwidth_type = bitwidth_value.type_signature
   if not type_analysis.is_valid_bitwidth_type_for_value_type(
       bitwidth_type, value_member_type):
     raise TypeError(
         'Expected `federated_secure_sum` parameter `bitwidth` to match '
         'the structure of `value`, with one integer bitwidth per tensor in '
         '`value`. Found `value` of `{}` and `bitwidth` of `{}`.'.format(
             value_member_type, bitwidth_type))
   if bitwidth_type.is_tensor() and value_member_type.is_struct():
     bitwidth_value = value_impl.to_value(
         structure.map_structure(lambda _: bitwidth, value_member_type), None,
         self._context_stack)
   value = value_impl.ValueImpl.get_comp(value)
   bitwidth_value = value_impl.ValueImpl.get_comp(bitwidth_value)
   comp = building_block_factory.create_federated_secure_sum(
       value, bitwidth_value)
   comp = self._bind_comp_as_reference(comp)
   return value_impl.ValueImpl(comp, self._context_stack)
Example #18
def serialize_jax_computation(traced_fn, arg_fn, parameter_type,
                              context_stack):
    """Serializes a Python function containing JAX code as a TFF computation.

  Args:
    traced_fn: The Python function containing JAX code to be traced by JAX and
      serialized as a TFF computation containing XLA code.
    arg_fn: An unpacking function that takes a TFF argument and returns a pair
      of (args, kwargs) with which to invoke `traced_fn` (e.g., one constructed
      by `function_utils.create_argument_unpacking_fn`).
    parameter_type: An instance of `computation_types.Type` that represents the
      TFF type of the computation parameter, or `None` if the function does not
      take any parameters.
    context_stack: The context stack to use during serialization.

  Returns:
    An instance of `pb.Computation` with the constructed computation.

  Raises:
    TypeError: if the arguments are of the wrong types.
  """
    py_typecheck.check_callable(traced_fn)
    py_typecheck.check_callable(arg_fn)
    py_typecheck.check_type(context_stack, context_stack_base.ContextStack)

    if parameter_type is not None:
        parameter_type = computation_types.to_type(parameter_type)
        packed_arg = _tff_type_to_xla_serializer_arg(parameter_type)
    else:
        packed_arg = None

    args, kwargs = arg_fn(packed_arg)

    def _adjust_arg(x):
        return type_conversions.type_to_py_container(x, x.type_signature)

    args = [_adjust_arg(x) for x in args]
    kwargs = {k: _adjust_arg(v) for k, v in kwargs.items()}

    context = jax_computation_context.JaxComputationContext()
    with context_stack.install(context):
        tracer_callable = jax.xla_computation(traced_fn,
                                              tuple_args=True,
                                              return_shape=True)
        compiled_xla, returned_shape = tracer_callable(*args, **kwargs)

    if isinstance(returned_shape, jax.ShapeDtypeStruct):
        returned_type_spec = _jax_shape_dtype_struct_to_tff_tensor(
            returned_shape)
    else:
        returned_type_spec = computation_types.to_type(
            structure.map_structure(
                _jax_shape_dtype_struct_to_tff_tensor,
                structure.from_container(returned_shape, recursive=True)))

    computation_type = computation_types.FunctionType(parameter_type,
                                                      returned_type_spec)
    return xla_serialization.create_xla_tff_computation(
        compiled_xla, computation_type)
def _ensure_structure(int_or_structure, int_or_structure_type,
                      possible_struct_type):
  if int_or_structure_type.is_struct() or not possible_struct_type.is_struct():
    return int_or_structure
  else:
    # Broadcast int_or_structure to the same structure as the struct type.
    return structure.map_structure(lambda *args: int_or_structure,
                                   possible_struct_type)
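# A small sketch of the broadcast behavior, assuming the helper above and
# TFF's type constructors are in scope: a scalar bitwidth is replicated
# across a struct type.
import tensorflow as tf
import tensorflow_federated as tff

int32_type = tff.TensorType(tf.int32)
struct_type = tff.StructType([int32_type, int32_type])
print(_ensure_structure(8, int32_type, struct_type))  # expected: <8,8>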
 def apply_sampling(accumulators, rands):
   size = tf.shape(rands)[0]
   k = tf.minimum(size, max_num_samples)
   indices = tf.math.top_k(rands, k=k).indices
   # TODO(b/121288403): Special-casing anonymous tuple shouldn't be needed.
   if member_type.is_struct():
     return structure.map_structure(lambda v: fed_gather(v, indices),
                                    accumulators), fed_gather(rands, indices)
   return fed_gather(accumulators, indices), fed_gather(rands, indices)
 def merge(a, b):
   """Merges accumulators through concatenation."""
   # TODO(b/121288403): Special-casing anonymous tuple shouldn't be needed.
   if accumulator_type.is_struct():
     samples = structure.map_structure(fed_concat, _ensure_structure(a),
                                       _ensure_structure(b))
   else:
     samples = fed_concat(a, b)
   accumulators, rands = apply_sampling(samples.accumulators, samples.rands)
   return _Samples(accumulators, rands)
 def accumulator_type_fn():
   """Gets the type for the accumulators."""
   # TODO(b/121288403): Special-casing anonymous tuple shouldn't be needed.
   if member_type.is_struct():
     a = structure.map_structure(
         lambda v: tf.zeros([0] + v.shape.dims, v.dtype), member_type)
     return _Samples(structure.to_odict(a, True), tf.zeros([0], tf.float32))
   s = [0] + member_type.shape.dims
   return _Samples(tf.zeros(s, member_type.dtype), tf.zeros([0], tf.float32))
Example #23
def _validate_value_type_and_encoders(value_type, encoders, encoder_type):
  """Validates if `value_type` and `encoders` are compatible."""
  if isinstance(encoders, _ALLOWED_ENCODERS):
    # If `encoders` is not a container, then `value_type` should be an
    # instance of `tff.TensorType`.
    if not isinstance(value_type, computation_types.TensorType):
      raise ValueError(
          '`value_type` and `encoders` do not have the same structure.')

    _validate_encoder(encoders, value_type, encoder_type)
  else:
    # If `encoders` is a container, then `value_type` should be an instance
    # of `tff.StructType`.
    if not isinstance(value_type, computation_types.StructType):
      raise TypeError('`value_type` is not compatible with the expected input '
                      'of the `encoders`.')
    structure.map_structure(lambda e, v: _validate_encoder(e, v, encoder_type),
                            structure.from_container(encoders, recursive=True),
                            value_type)
Example #24
def build_zero_argument(parameter_type):
    if parameter_type is None:
        return None
    elif parameter_type.is_struct():
        return structure.map_structure(build_zero_argument, parameter_type)
    elif parameter_type == tffint32:
        return 0
    elif parameter_type == tffstring:
        return ''
    else:
        raise NotImplementedError(f'Unsupported type: {parameter_type}')
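# A hypothetical usage sketch, assuming `build_zero_argument` is in scope and
# that `tffint32`/`tffstring` are aliases like the ones below.
import tensorflow as tf
import tensorflow_federated as tff

tffint32 = tff.TensorType(tf.int32)
tffstring = tff.TensorType(tf.string)
print(build_zero_argument(tff.StructType([tffint32, tffstring])))
# expected: a struct like <0,> (an int zero and an empty string)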
Example #25
def to_representation_for_type(value, type_spec, backend=None):
    """Verifies or converts the `value` to executor payload matching `type_spec`.

  The following kinds of `value` are supported:

  * Computations, either `pb.Computation` or `computation_impl.ComputationImpl`.

  * Numpy arrays and scalars, or Python scalars that are converted to Numpy.

  * Nested structures of the above.

  Args:
    value: The raw representation of a value to compare against `type_spec` and
      potentially to be converted.
    type_spec: An instance of `tff.Type`. Can be `None` for values that derive
      from `typed_object.TypedObject`.
    backend: The backend to use; an instance of `xla_client.Client`. Only used
      for functional types. Can be `None` if unused.

  Returns:
    Either `value` itself, or a modified version of it.

  Raises:
    TypeError: If the `value` is not compatible with `type_spec`.
    ValueError: If the arguments are incorrect.
  """
    if backend is not None:
        py_typecheck.check_type(backend, xla_client.Client)
    if type_spec is not None:
        type_spec = computation_types.to_type(type_spec)
    type_spec = type_utils.reconcile_value_with_type_spec(value, type_spec)
    if isinstance(value, computation_base.Computation):
        return to_representation_for_type(
            computation_impl.ComputationImpl.get_proto(value), type_spec,
            backend)
    if isinstance(value, pb.Computation):
        comp_type = type_serialization.deserialize_type(value.type)
        if type_spec is not None:
            comp_type.check_equivalent_to(type_spec)
        return _ComputationCallable(value, comp_type, backend)
    if isinstance(type_spec, computation_types.StructType):
        return structure.map_structure(
            lambda v, t: to_representation_for_type(v, t, backend),
            structure.from_container(value, recursive=True), type_spec)
    if isinstance(type_spec, computation_types.TensorType):
        type_spec.shape.assert_is_fully_defined()
        type_analysis.check_type(value, type_spec)
        if type_spec.shape.rank == 0:
            return np.dtype(type_spec.dtype.as_numpy_dtype).type(value)
        if type_spec.shape.rank > 0:
            return np.array(value, dtype=type_spec.dtype.as_numpy_dtype)
        raise TypeError('Unsupported tensor shape {}.'.format(type_spec.shape))
    raise TypeError('Unexpected type {}.'.format(type_spec))
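# A standalone sketch of the tensor branch above: rank-0 values become numpy
# scalars of the target dtype, and higher-rank values become ndarrays.
import numpy as np

print(np.dtype(np.int32).type(5))          # 5 as a numpy int32 scalar
print(np.array([1, 2], dtype=np.float32))  # [1. 2.] as a rank-1 ndarray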
  def accumulate(current, value):
    """Accumulates samples through concatenation."""
    rands = fed_concat_expand_dims(current.rands, tf.random.uniform(shape=()))
    # TODO(b/121288403): Special-casing anonymous tuple shouldn't be needed.
    if member_type.is_struct():
      accumulators = structure.map_structure(
          fed_concat_expand_dims, _ensure_structure(current.accumulators),
          _ensure_structure(value))
    else:
      accumulators = fed_concat_expand_dims(current.accumulators, value)

    accumulators, rands = apply_sampling(accumulators, rands)
    return _Samples(accumulators, rands)
    def test_returns_value_with_intrinsic_def_federated_secure_sum(
            self, client_values, bitwidth, expected_result):
        executor = create_test_executor()
        value_type = computation_types.at_clients(
            type_conversions.infer_type(client_values[0]))
        bitwidth_type = type_conversions.infer_type(bitwidth)
        comp, comp_type = create_intrinsic_def_federated_secure_sum(
            value_type.member, bitwidth_type)

        comp = self.run_sync(executor.create_value(comp, comp_type))
        arg_1 = self.run_sync(executor.create_value(client_values, value_type))
        arg_2 = self.run_sync(executor.create_value(bitwidth, bitwidth_type))
        args = self.run_sync(executor.create_struct([arg_1, arg_2]))
        result = self.run_sync(executor.create_call(comp, args))

        self.assertIsInstance(result, executor_value_base.ExecutorValue)
        self.assert_types_identical(result.type_signature, comp_type.result)
        actual_result = self.run_sync(result.compute())
        if isinstance(expected_result, structure.Struct):
            structure.map_structure(self.assertAllEqual, actual_result,
                                    expected_result)
        else:
            self.assertEqual(actual_result, expected_result)
Example #28
    def test_with_tuple_of_unnamed_elements(self):
        ex, _ = _make_executor_and_tracer_for_test()
        loop = asyncio.get_event_loop()

        v1 = loop.run_until_complete(ex.create_value(10, tf.int32))
        self.assertEqual(str(v1.identifier), '1')
        v2 = loop.run_until_complete(ex.create_value(11, tf.int32))
        self.assertEqual(str(v2.identifier), '2')
        v3 = loop.run_until_complete(ex.create_struct([v1, v2]))
        self.assertEqual(str(v3.identifier), '<1,2>')
        v4 = loop.run_until_complete(ex.create_struct((v1, v2)))
        self.assertIs(v4, v3)
        c4 = loop.run_until_complete(v4.compute())
        self.assertEqual(str(structure.map_structure(lambda x: x.numpy(), c4)),
                         '<10,11>')
def max_input_from_bitwidth(bitwidth):
    # Secure sum is performed with int64, which leaves 63 bits for the
    # magnitude, and we need at least one spare bit to hold the summation of
    # two client values.
    max_secure_sum_bitwidth = 62

    def compute_max_input(bits):
        assert_op = tf.Assert(
            tf.less_equal(bits, max_secure_sum_bitwidth), [
                bits,
                f'is greater than maximum bitwidth {max_secure_sum_bitwidth}'
            ])
        with tf.control_dependencies([assert_op]):
            return tf.math.pow(tf.constant(2, tf.int64),
                               tf.cast(bits, tf.int64)) - 1

    return structure.map_structure(compute_max_input, bitwidth)
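# A worked check of the arithmetic above in plain TensorFlow: a bitwidth of 8
# allows client values up to 2**8 - 1 = 255.
import tensorflow as tf

print(tf.math.pow(tf.constant(2, tf.int64), tf.constant(8, tf.int64)) - 1)
# tf.Tensor(255, shape=(), dtype=int64)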
Example #30
def initialize():
    # Allow fixed seeds, otherwise set a sentinel that signals a seed should
    # be generated upon the first `accumulate` call of the
    # `federated_aggregate`.
    if seed is None:
        real_seed = tf.convert_to_tensor(SEED_SENTINEL, dtype=tf.int64)
    elif tf.is_tensor(seed):
        if seed.dtype != tf.int64:
            real_seed = tf.cast(seed, dtype=tf.int64)
        else:
            real_seed = seed
    else:
        real_seed = tf.convert_to_tensor(seed, dtype=tf.int64)

    def zero_for_tensor_type(t: computation_types.TensorType):
        """Adds an extra first dimension to create a tensor that collects samples.

        The first dimension will have size `0` for the algebraic zero,
        resulting in an "empty" tensor. This will be concatenated to as
        samples fill the reservoir.

        Args:
          t: A `tff.TensorType` to build a sampling zero value for.

        Returns:
          A tensor whose rank is one larger than before, and whose first
          dimension is zero.

        Raises:
          `TypeError` if `t` is not a `tff.TensorType`.
        """
        if not t.is_tensor():
            raise TypeError(f'Cannot create zero for non-TensorType: {type(t)}')
        return tf.zeros([0] + t.shape, dtype=t.dtype)

    if sample_value_type.is_tensor():
        initial_samples = zero_for_tensor_type(sample_value_type)
    elif sample_value_type.is_struct():
        initial_samples = structure.map_structure(zero_for_tensor_type,
                                                  sample_value_type)
    else:
        raise TypeError(
            'Cannot build initial reservoir for structure that has '
            'types other than StructWithPythonType or TensorType, '
            f'got {sample_value_type!r}.')
    return collections.OrderedDict(
        random_seed=tf.fill(dims=(2,), value=real_seed),
        random_values=tf.zeros([0], tf.int32),
        samples=initial_samples)
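# A small sketch of the sampling zero built by `zero_for_tensor_type`: for a
# float32[3] leaf it is an empty [0, 3] tensor that later samples are
# concatenated onto.
import tensorflow as tf

z = tf.zeros([0, 3], dtype=tf.float32)
print(z.shape)                                  # (0, 3)
print(tf.concat([z, tf.ones([1, 3])], axis=0))  # first "sample" appended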