示例#1
0
 def sequence_reduce(self, value, zero, op):
     """Implements `sequence_reduce` as defined in `api/intrinsics.py`.

     A plain sequence is reduced directly. For a federated sequence, the
     reduction is wrapped in a lambda over `<member, zero_member>` and applied
     to the pair `<value, zero>` with `federated_map`.

     Args:
       value: A sequence, or a federated value whose member is a sequence.
       zero: The initial accumulator. When `value` is federated, `zero` is
         expected to be federated too (its `member` type is read).
       op: The reduction operator.

     Returns:
       A `value_impl.ValueImpl` holding the reduction result.
     """
     value = value_impl.to_value(value, None, self._context_stack)
     zero = value_impl.to_value(zero, None, self._context_stack)
     op = value_impl.to_value(op, None, self._context_stack)
     value_type = value.type_signature

     if not value_type.is_federated():
         # Plain sequence: reduce it directly and bind the result to a symbol.
         value_type.check_sequence()
         reduce_comp = building_block_factory.create_sequence_reduce(
             value_impl.ValueImpl.get_comp(value),
             value_impl.ValueImpl.get_comp(zero),
             value_impl.ValueImpl.get_comp(op))
         reduce_comp = self._bind_comp_as_reference(reduce_comp)
         return value_impl.ValueImpl(reduce_comp, self._context_stack)

     # Federated sequence: build `(arg -> sequence_reduce(arg[0], arg[1], op))`
     # and map it over the pair `<value, zero>`.
     member_type = value_type.member
     member_type.check_sequence()
     zero_member_type = zero.type_signature.member
     param = building_blocks.Reference(
         'arg', computation_types.StructType([member_type, zero_member_type]))
     reduce_call = building_block_factory.create_sequence_reduce(
         building_blocks.Selection(param, index=0),
         building_blocks.Selection(param, index=1),
         value_impl.ValueImpl.get_comp(op))
     mapper = value_impl.ValueImpl(
         building_blocks.Lambda(param.name, param.type_signature, reduce_call),
         self._context_stack)
     paired_args = building_blocks.Struct([
         value_impl.ValueImpl.get_comp(value),
         value_impl.ValueImpl.get_comp(zero),
     ])
     return self.federated_map(mapper, paired_args)
示例#2
0
    def sequence_map(self, fn, arg):
        """Implements `sequence_map` as defined in `api/intrinsics.py`.

        Args:
          fn: A value of functional type applied to each sequence element.
          arg: A sequence, or a federated value (which is then mapped through
            `federated_map`).

        Returns:
          A `value_impl.ValueImpl` for the mapped result.

        Raises:
          TypeError: If `fn` is not of a `FunctionType`, or `arg` is neither a
            sequence nor a federated value.
        """
        fn = value_impl.to_value(fn, None, self._context_stack)
        py_typecheck.check_type(fn.type_signature,
                                computation_types.FunctionType)
        arg = value_impl.to_value(arg, None, self._context_stack)
        arg_type = arg.type_signature

        if arg_type.is_sequence():
            mapped = building_block_factory.create_sequence_map(
                value_impl.ValueImpl.get_comp(fn),
                value_impl.ValueImpl.get_comp(arg))
            return value_impl.ValueImpl(self._bind_comp_as_reference(mapped),
                                        self._context_stack)

        if not arg_type.is_federated():
            raise TypeError(
                'Cannot apply `tff.sequence_map()` to a value of type {}.'.
                format(arg_type))

        # Federated argument: curry the raw `sequence_map` intrinsic on `fn`
        # and push the remaining one-argument function through `federated_map`.
        fn_type = fn.type_signature
        sequence_map_type = computation_types.FunctionType(
            (fn_type, computation_types.SequenceType(fn_type.parameter)),
            computation_types.SequenceType(fn_type.result))
        sequence_map_impl = value_impl.ValueImpl(
            building_blocks.Intrinsic(intrinsic_defs.SEQUENCE_MAP.uri,
                                      sequence_map_type), self._context_stack)
        local_fn = value_utils.get_curried(sequence_map_impl)(fn)
        return self.federated_map(local_fn, arg)
示例#3
0
    def sequence_sum(self, value):
        """Implements `sequence_sum` as defined in `api/intrinsics.py`.

        Args:
          value: A sequence, or a federated value whose member is a sequence,
            with sum-compatible elements.

        Returns:
          A `value_impl.ValueImpl` holding the sum (computed per member via
          `federated_map` when `value` is federated).

        Raises:
          TypeError: If `value` is neither a sequence nor a federated sequence.
        """
        value = value_impl.to_value(value, None, self._context_stack)
        value_type = value.type_signature

        if value_type.is_sequence():
            type_analysis.check_is_sum_compatible(value_type.element)
            summed = building_block_factory.create_sequence_sum(
                value_impl.ValueImpl.get_comp(value))
            summed = self._bind_comp_as_reference(summed)
            return value_impl.ValueImpl(summed, self._context_stack)

        # Anything that is not a plain sequence must be a federated sequence.
        py_typecheck.check_type(value_type, computation_types.FederatedType)
        py_typecheck.check_type(value_type.member,
                                computation_types.SequenceType)
        element_type = value_type.member.element
        type_analysis.check_is_sum_compatible(element_type)

        # Apply the raw `sequence_sum` intrinsic to each member.
        sum_intrinsic = building_blocks.Intrinsic(
            intrinsic_defs.SEQUENCE_SUM.uri,
            computation_types.FunctionType(value_type.member, element_type))
        return self.federated_map(
            value_impl.ValueImpl(sum_intrinsic, self._context_stack), value)
示例#4
0
 def test_to_value_for_list(self):
   """A Python list converts to an unnamed struct of its elements."""
   stack = context_stack_impl.context_stack
   foo = value_impl.ValueImpl(
       building_blocks.Reference('foo', tf.int32), stack)
   bar = value_impl.ValueImpl(
       building_blocks.Reference('bar', tf.bool), stack)
   converted = value_impl.to_value([foo, bar], None, stack)
   self.assertIsInstance(converted, value_base.Value)
   self.assertEqual(str(converted), '<foo,bar>')
示例#5
0
 def test_to_value_for_structure(self):
   """Names of a `structure.Struct` carry over into the converted value."""
   stack = context_stack_impl.context_stack
   foo = value_impl.ValueImpl(
       building_blocks.Reference('foo', tf.int32), stack)
   bar = value_impl.ValueImpl(
       building_blocks.Reference('bar', tf.bool), stack)
   named_struct = structure.Struct([('a', foo), ('b', bar)])
   converted = value_impl.to_value(named_struct, None, stack)
   self.assertIsInstance(converted, value_base.Value)
   self.assertEqual(str(converted), '<a=foo,b=bar>')
示例#6
0
 def test_to_value_for_ordered_dict(self):
   """An `OrderedDict` keeps its insertion order (`b` before `a`)."""
   stack = context_stack_impl.context_stack
   foo = value_impl.ValueImpl(
       building_blocks.Reference('foo', tf.int32), stack)
   bar = value_impl.ValueImpl(
       building_blocks.Reference('bar', tf.bool), stack)
   ordered = collections.OrderedDict([('b', bar), ('a', foo)])
   converted = value_impl.to_value(ordered, None, stack)
   self.assertIsInstance(converted, value_base.Value)
   self.assertEqual(str(converted), '<b=bar,a=foo>')
示例#7
0
 def test_to_value_for_named_tuple(self):
   """Namedtuple field names become the names of the struct elements."""
   stack = context_stack_impl.context_stack
   foo = value_impl.ValueImpl(
       building_blocks.Reference('foo', tf.int32), stack)
   bar = value_impl.ValueImpl(
       building_blocks.Reference('bar', tf.bool), stack)
   pair_type = collections.namedtuple('_', 'a b')
   converted = value_impl.to_value(pair_type(foo, bar), None, stack)
   self.assertIsInstance(converted, value_base.Value)
   self.assertEqual(str(converted), '<a=foo,b=bar>')
示例#8
0
 def test_value_impl_with_plus(self):
   """`x + y` binds a single `generic_plus` symbol in the current context."""
   stack = context_stack_impl.context_stack
   x = value_impl.ValueImpl(building_blocks.Reference('x', tf.int32), stack)
   y = value_impl.ValueImpl(building_blocks.Reference('y', tf.int32), stack)
   total = x + y
   self.assertIsInstance(total, value_base.Value)
   self.assertEqual(str(total.type_signature), 'int32')
   self.assertEqual(str(total), 'fc_FEDERATED_symbol_0')
   current_context = value_impl.ValueImpl.get_context_stack(total).current
   bindings = current_context.symbol_bindings
   self.assertLen(bindings, 1)
   symbol_name, bound_comp = bindings[0]
   self.assertEqual(symbol_name, 'fc_FEDERATED_symbol_0')
   self.assertEqual(bound_comp.compact_representation(), 'generic_plus(<x,y>)')
示例#9
0
    def federated_map_all_equal(self, fn, arg):
        """`federated_map` with the `all_equal` bit set in the `arg` and return.

        Args:
          fn: A mapping function, or something `value_impl.to_value` can
            convert into one, whose parameter accepts the member type of `arg`.
          arg: A value placed at `tff.CLIENTS`, or a struct that
            `value_utils.ensure_federated_value` can implicitly zip into one.

        Returns:
          A `value_impl.ValueImpl` bound to the resulting all-equal map.

        Raises:
          TypeError: If `arg` is not (convertible to) a CLIENTS-placed value,
            or the parameter of `fn` cannot accept the member type of `arg`.
        """
        # TODO(b/113112108): Possibly lift the restriction that the mapped value
        # must be placed at the clients after adding support for placement labels
        # in the federated types, and expanding the type specification of the
        # intrinsic this is based on to work with federated values of arbitrary
        # placement.
        arg = value_utils.ensure_federated_value(
            value_impl.to_value(arg, None, self._context_stack),
            placement_literals.CLIENTS, 'value to be mapped')
        member_type = arg.type_signature.member

        # The member type is supplied as `parameter_type_hint` when converting
        # `fn`.
        fn = value_impl.to_value(fn,
                                 None,
                                 self._context_stack,
                                 parameter_type_hint=member_type)
        py_typecheck.check_type(fn, value_base.Value)
        py_typecheck.check_type(fn.type_signature,
                                computation_types.FunctionType)

        expected_type = fn.type_signature.parameter
        if not expected_type.is_assignable_from(member_type):
            raise TypeError(
                'The mapping function expects a parameter of type {}, but member '
                'constituents of the mapped value are of incompatible type {}.'
                .format(expected_type, member_type))

        mapped = building_block_factory.create_federated_map_all_equal(
            value_impl.ValueImpl.get_comp(fn),
            value_impl.ValueImpl.get_comp(arg))
        return value_impl.ValueImpl(self._bind_comp_as_reference(mapped),
                                    self._context_stack)
示例#10
0
def get_curried(fn):
    """Returns a curried version of function `fn` that takes a parameter tuple.

    For a function `fn` of type `<T1,T2,...,Tn> -> U`, the result is a function
    of the form `T1 -> (T2 -> (T3 -> ... (Tn -> U) ... ))`.

    Note: No attempt is made to avoid naming conflicts in cases where `fn`
    contains references. The arguments of the curried function are named
    `argN`, with `N` starting at 0.

    Args:
      fn: A value of a functional TFF type whose parameter is a `StructType`.

    Returns:
      A value that represents the curried form of `fn`, attached to the same
      context stack as `fn`.
    """
    py_typecheck.check_type(fn, value_base.Value)
    fn_type = fn.type_signature
    py_typecheck.check_type(fn_type, computation_types.FunctionType)
    py_typecheck.check_type(fn_type.parameter, computation_types.StructType)

    arg_refs = [
        building_blocks.Reference('arg{}'.format(index), element_type)
        for index, (_, element_type) in enumerate(
            structure.to_elements(fn_type.parameter))
    ]
    # Innermost body: call `fn` on all the arguments packed back into a struct.
    curried = building_blocks.Call(value_impl.ValueImpl.get_comp(fn),
                                   building_blocks.Struct(arg_refs))
    # Wrap from the last argument outwards so that `arg0` ends up outermost.
    for ref in reversed(arg_refs):
        curried = building_blocks.Lambda(ref.name, ref.type_signature, curried)
    return value_impl.ValueImpl(curried,
                                value_impl.ValueImpl.get_context_stack(fn))
示例#11
0
 def federated_secure_sum(self, value, bitwidth):
     """Implements `federated_secure_sum` as defined in `api/intrinsics.py`.

     Args:
       value: A value placed at `tff.CLIENTS` (or an implicitly zippable
         struct) whose type is a structure of integers.
       bitwidth: Either a single bitwidth, which is broadcast to every tensor
         when `value` is a struct, or a structure matching `value` with one
         integer bitwidth per tensor.

     Returns:
       A `value_impl.ValueImpl` bound to the secure-sum computation.

     Raises:
       TypeError: If `value` is not (convertible to) a CLIENTS-placed value,
         or `bitwidth` does not match the structure of `value`.
     """
     value = value_utils.ensure_federated_value(
         value_impl.to_value(value, None, self._context_stack),
         placement_literals.CLIENTS, 'value to be summed')
     type_analysis.check_is_structure_of_integers(value.type_signature)

     member_type = value.type_signature.member
     bitwidth_value = value_impl.to_value(bitwidth, None, self._context_stack)
     bitwidth_type = bitwidth_value.type_signature
     if not type_analysis.is_valid_bitwidth_type_for_value_type(
             bitwidth_type, member_type):
         raise TypeError(
             'Expected `federated_secure_sum` parameter `bitwidth` to match '
             'the structure of `value`, with one integer bitwidth per tensor in '
             '`value`. Found `value` of `{}` and `bitwidth` of `{}`.'.
             format(member_type, bitwidth_type))

     # A single tensor bitwidth is replicated into the structure of `value`.
     if bitwidth_type.is_tensor() and member_type.is_struct():
         replicated = structure.map_structure(lambda _: bitwidth, member_type)
         bitwidth_value = value_impl.to_value(replicated, None,
                                              self._context_stack)

     secure_sum = building_block_factory.create_federated_secure_sum(
         value_impl.ValueImpl.get_comp(value),
         value_impl.ValueImpl.get_comp(bitwidth_value))
     return value_impl.ValueImpl(self._bind_comp_as_reference(secure_sum),
                                 self._context_stack)
示例#12
0
 def test_value_impl_with_selection(self):
   """Named struct values support attribute, key, index and iteration access."""
   ref = building_blocks.Reference('foo', [('bar', tf.int32), ('baz', tf.bool)])
   x = value_impl.ValueImpl(ref, context_stack_impl.context_stack)
   self.assertContainsSubset(['bar', 'baz'], dir(x))
   self.assertLen(x, 2)

   # Selection by attribute name and by string key.
   by_attr = x.bar
   self.assertIsInstance(by_attr, value_base.Value)
   self.assertEqual(str(by_attr.type_signature), 'int32')
   self.assertEqual(str(by_attr), 'foo.bar')
   by_key = x['baz']
   self.assertEqual(str(by_key.type_signature), 'bool')
   self.assertEqual(str(by_key), 'foo.baz')
   with self.assertRaises(AttributeError):
     _ = x.bak

   # Selection by integer index; out-of-range and negative indices fail.
   first = x[0]
   self.assertIsInstance(first, value_base.Value)
   self.assertEqual(str(first.type_signature), 'int32')
   self.assertEqual(str(first), 'foo[0]')
   second = x[1]
   self.assertEqual(str(second.type_signature), 'bool')
   self.assertEqual(str(second), 'foo[1]')
   for bad_index in (2, -1):
     with self.assertRaises(IndexError):
       _ = x[bad_index]

   elements = list(iter(x))
   self.assertEqual(','.join(map(str, elements)), 'foo[0],foo[1]')
   self.assertEqual(','.join(str(e.type_signature) for e in elements),
                    'int32,bool')
   # A non-functional value cannot be called.
   with self.assertRaises(SyntaxError):
     x(10)
示例#13
0
  def test_value_impl_dir(self):
    """`dir()` on a value returns a non-empty list exposing `type_signature`."""
    value = value_impl.ValueImpl(
        building_blocks.Reference('foo', tf.int32),
        context_stack_impl.context_stack)
    attributes = dir(value)
    self.assertIsInstance(attributes, list)
    self.assertNotEmpty(attributes)
    self.assertIn('type_signature', attributes)
示例#14
0
 def test_value_impl_with_reference(self):
   """A value wrapping a `Reference` exposes the very same building block."""
   ref_comp = building_blocks.Reference('foo', tf.int32)
   ref_value = value_impl.ValueImpl(ref_comp, context_stack_impl.context_stack)
   self.assertIs(value_impl.ValueImpl.get_comp(ref_value), ref_comp)
   self.assertEqual(str(ref_value.type_signature), 'int32')
   self.assertEqual(repr(ref_value), 'Reference(\'foo\', TensorType(tf.int32))')
   self.assertEqual(str(ref_value), 'foo')
   # A non-functional value cannot be called.
   with self.assertRaises(SyntaxError):
     ref_value(10)
示例#15
0
 def test_to_value_for_dict(self):
   """Plain dicts convert to the same struct regardless of insertion order."""
   stack = context_stack_impl.context_stack
   foo = value_impl.ValueImpl(
       building_blocks.Reference('foo', tf.int32), stack)
   bar = value_impl.ValueImpl(
       building_blocks.Reference('bar', tf.bool), stack)
   for entries in ({'a': foo, 'b': bar}, {'b': bar, 'a': foo}):
     converted = value_impl.to_value(entries, None, stack)
     self.assertIsInstance(converted, value_base.Value)
     self.assertEqual(str(converted), '<a=foo,b=bar>')
示例#16
0
 def test_slicing_support_non_tuple_underlying_comp(self):
   """Slicing works on a value whose comp is a `Reference`, not a `Struct`."""
   five_ints = building_blocks.Reference('test', [tf.int32] * 5)
   v = value_impl.ValueImpl(five_ints, context_stack_impl.context_stack)
   # Forward and reverse strided slices both yield values.
   for index_slice in (slice(None, 4, 2), slice(4, 2, -1)):
     self.assertIsInstance(v[index_slice], value_base.Value)
   # A slice that selects no elements is rejected.
   with self.assertRaisesRegex(IndexError, 'slice 0 elements'):
     _ = v[2:4:-1]
示例#17
0
    def federated_collect(self, value):
        """Implements `federated_collect` as defined in `api/intrinsics.py`.

        Args:
          value: A value placed at `tff.CLIENTS`, or a struct that
            `value_utils.ensure_federated_value` can implicitly zip into one.

        Returns:
          A `value_impl.ValueImpl` bound to the collect computation.

        Raises:
          TypeError: If `value` is not (convertible to) a CLIENTS-placed value.
        """
        clients_value = value_utils.ensure_federated_value(
            value_impl.to_value(value, None, self._context_stack),
            placement_literals.CLIENTS, 'value to be collected')
        collected = building_block_factory.create_federated_collect(
            value_impl.ValueImpl.get_comp(clients_value))
        return value_impl.ValueImpl(self._bind_comp_as_reference(collected),
                                    self._context_stack)
示例#18
0
def federated_computation_serializer(
    parameter_name: Optional[str],
    parameter_type: Optional[computation_types.Type],
    context_stack: context_stack_base.ContextStack,
    suggested_name: Optional[str] = None,
):
    """Converts a function into a computation building block.

    This is a two-step generator. The first `yield` hands the caller the
    argument to pass to the Python function being converted. The caller then
    sends that function's result back in (via `send()`). The second `yield`
    produces the finished lambda and its type.

    Args:
      parameter_name: The name of the parameter, or `None` if there isn't any.
      parameter_type: The `tff.Type` of the parameter, or `None` if there's
        none.
      context_stack: The context stack to use.
      suggested_name: The optional suggested name to use for the federated
        context that will be used to serialize this function's body (ideally
        the name of the underlying Python function). It might be modified to
        avoid conflicts.

    Yields:
      First, the argument to be passed to the function to be converted (a
      `value_impl.ValueImpl` wrapping a `Reference` to the parameter, or `None`
      when `parameter_type` is `None`).
      Finally, a tuple of `(building_blocks.ComputationBuildingBlock,
      computation_types.Type)`: the function represented via building blocks
      and the inferred function type.
    """
    py_typecheck.check_type(context_stack, context_stack_base.ContextStack)
    if suggested_name is not None:
        py_typecheck.check_type(suggested_name, str)
    # When already tracing inside a federated computation, nest the new
    # context under the current one.
    if isinstance(context_stack.current,
                  federated_computation_context.FederatedComputationContext):
        parent_context = context_stack.current
    else:
        parent_context = None
    context = federated_computation_context.FederatedComputationContext(
        context_stack, suggested_name=suggested_name, parent=parent_context)
    if parameter_name is not None:
        py_typecheck.check_type(parameter_name, str)
        # Prefix the parameter with the context's name; presumably this keeps
        # parameter names distinct across nested contexts (TODO confirm).
        parameter_name = '{}_{}'.format(context.name, str(parameter_name))
    with context_stack.install(context):
        # Step 1: hand out the parameter (or None) and receive the traced
        # function's result through `send()`.
        if parameter_type is None:
            result = yield None
        else:
            result = yield (value_impl.ValueImpl(
                building_blocks.Reference(parameter_name, parameter_type),
                context_stack))
        annotated_result_type = type_conversions.infer_type(result)
        result = value_impl.to_value(result, annotated_result_type,
                                     context_stack)
        result_comp = value_impl.ValueImpl.get_comp(result)
        # Symbols bound while tracing become the locals of a `Block` that
        # wraps the result. They must be read while `context` is installed.
        symbols_bound_in_context = context_stack.current.symbol_bindings
        if symbols_bound_in_context:
            result_comp = building_blocks.Block(
                local_symbols=symbols_bound_in_context, result=result_comp)
        annotated_type = computation_types.FunctionType(
            parameter_type, annotated_result_type)
        # Step 2: produce the finished lambda and its function type.
        yield building_blocks.Lambda(parameter_name, parameter_type,
                                     result_comp), annotated_type
示例#19
0
 def test_slicing_support_namedtuple(self):
   """Slicing a namedtuple-derived value returns struct values."""
   stack = context_stack_impl.context_stack
   foo = value_impl.ValueImpl(
       building_blocks.Reference('foo', tf.int32), stack)
   bar = value_impl.ValueImpl(
       building_blocks.Reference('bar', tf.bool), stack)
   pair_type = collections.namedtuple('_', 'a b')
   v = value_impl.to_value(pair_type(foo, bar), None, stack)
   self.assertIsInstance(v[:len(v) // 2], value_base.Value)
   # Each (slice, expected representation) pair is checked in turn.
   for index_slice, expected in ((slice(None, 4, 2), '<foo>'),
                                 (slice(4, None, -1), '<bar,foo>')):
     sliced_v = v[index_slice]
     self.assertEqual(str(sliced_v), expected)
     self.assertIsInstance(sliced_v, value_base.Value)
   # A slice that selects no elements is rejected.
   with self.assertRaisesRegex(IndexError, 'slice 0 elements'):
     _ = v[2:4]
示例#20
0
 def federated_sum(self, value):
     """Implements `federated_sum` as defined in `api/intrinsics.py`.

     Args:
       value: A value placed at `tff.CLIENTS` (or an implicitly zippable
         struct) whose type passes `type_analysis.check_is_sum_compatible`.

     Returns:
       A `value_impl.ValueImpl` bound to the federated sum computation.

     Raises:
       TypeError: If `value` is not (convertible to) a CLIENTS-placed value.
     """
     clients_value = value_utils.ensure_federated_value(
         value_impl.to_value(value, None, self._context_stack),
         placement_literals.CLIENTS, 'value to be summed')
     type_analysis.check_is_sum_compatible(clients_value.type_signature)
     summed = building_block_factory.create_federated_sum(
         value_impl.ValueImpl.get_comp(clients_value))
     return value_impl.ValueImpl(self._bind_comp_as_reference(summed),
                                 self._context_stack)
示例#21
0
    def federated_value(self, value, placement):
        """Implements `federated_value` as defined in `api/intrinsics.py`.

        Args:
          value: A non-federated value to place.
          placement: The placement to put `value` at.

        Returns:
          A `value_impl.ValueImpl` bound to the placed value.

        Raises:
          TypeError: If the type of `value` contains any federated type.
        """
        value = value_impl.to_value(value, None, self._context_stack)
        has_federated_part = type_analysis.contains(value.type_signature,
                                                    lambda t: t.is_federated())
        if has_federated_part:
            raise TypeError(
                'Cannot place value {} containing federated types at '
                'another placement; requested to be placed at {}.'.format(
                    value, placement))

        placed = building_block_factory.create_federated_value(
            value_impl.ValueImpl.get_comp(value), placement)
        return value_impl.ValueImpl(self._bind_comp_as_reference(placed),
                                    self._context_stack)
示例#22
0
 def test_value_impl_with_call(self):
   """Calling a functional value binds a symbol and checks its argument."""
   stack = context_stack_impl.context_stack
   fn_type = computation_types.FunctionType(tf.int32, tf.bool)
   x = value_impl.ValueImpl(building_blocks.Reference('foo', fn_type), stack)
   y = value_impl.ValueImpl(building_blocks.Reference('bar', tf.int32), stack)
   z = x(y)
   self.assertIsInstance(z, value_base.Value)
   self.assertEqual(str(z.type_signature), 'bool')
   self.assertEqual(str(z), 'fc_FEDERATED_symbol_0')
   bound_symbols = stack.current.symbol_bindings
   self.assertLen(bound_symbols, 1)
   symbol_name, bound_comp = bound_symbols[0]
   self.assertEqual(symbol_name, str(z))
   self.assertEqual(str(bound_comp), 'foo(bar)')
   # Calls with a missing or mistyped argument are rejected.
   with self.assertRaises(TypeError):
     x()
   wrong_arg = value_impl.ValueImpl(
       building_blocks.Reference('bak', tf.float32), stack)
   with self.assertRaises(TypeError):
     x(wrong_arg)
示例#23
0
    def federated_map(self, fn, arg):
        """Implements `federated_map` as defined in `api/intrinsics.py`.

        Args:
          fn: A mapping function, or something `value_impl.to_value` can
            convert into one, whose parameter accepts the member type of `arg`.
          arg: A federated value (or an implicitly zippable struct) placed at
            `tff.SERVER` or `tff.CLIENTS`.

        Returns:
          A `value_impl.ValueImpl` bound to a `federated_apply` computation for
          SERVER-placed `arg`, or a `federated_map` computation for CLIENTS.

        Raises:
          TypeError: If `fn` cannot accept the member type of `arg`, if a
            SERVER-placed `arg` is not `all_equal`, or if `arg` has any other
            placement.
        """
        # TODO(b/113112108): Possibly lift the restriction that the mapped value
        # must be placed at the server or clients. Would occur after adding support
        # for placement labels in the federated types, and expanding the type
        # specification of the intrinsic this is based on to work with federated
        # values of arbitrary placement.
        arg = value_utils.ensure_federated_value(
            value_impl.to_value(arg, None, self._context_stack),
            label='value to be mapped')
        member_type = arg.type_signature.member

        fn = value_impl.to_value(fn,
                                 None,
                                 self._context_stack,
                                 parameter_type_hint=member_type)
        py_typecheck.check_type(fn, value_base.Value)
        py_typecheck.check_type(fn.type_signature,
                                computation_types.FunctionType)
        expected_type = fn.type_signature.parameter
        if not expected_type.is_assignable_from(member_type):
            raise TypeError(
                'The mapping function expects a parameter of type {}, but member '
                'constituents of the mapped value are of incompatible type {}.'
                .format(expected_type, member_type))

        # TODO(b/144384398): Change structure to one that maps the placement type
        # to the building_block function that fits it, in a way that allows the
        # appropriate type checks.
        placement = arg.type_signature.placement
        if placement is placement_literals.SERVER:
            # A SERVER-placed argument must be identical at every location.
            if not arg.type_signature.all_equal:
                raise TypeError(
                    'Arguments placed at {} should be equal at all locations.'
                    .format(placement_literals.SERVER))
            build_fn = building_block_factory.create_federated_apply
        elif placement is placement_literals.CLIENTS:
            build_fn = building_block_factory.create_federated_map
        else:
            raise TypeError(
                'Expected `arg` to have a type with a supported placement, '
                'found {}.'.format(placement))

        mapped = build_fn(value_impl.ValueImpl.get_comp(fn),
                          value_impl.ValueImpl.get_comp(arg))
        return value_impl.ValueImpl(self._bind_comp_as_reference(mapped),
                                    self._context_stack)
示例#24
0
    def federated_broadcast(self, value):
        """Implements `federated_broadcast` as defined in `api/intrinsics.py`.

        Args:
          value: A value placed at `tff.SERVER` (or an implicitly zippable
            struct) whose type has `all_equal` set.

        Returns:
          A `value_impl.ValueImpl` bound to the broadcast computation.

        Raises:
          TypeError: If `value` is not (convertible to) a SERVER-placed value,
            or is not equal at all locations.
        """
        server_value = value_utils.ensure_federated_value(
            value_impl.to_value(value, None, self._context_stack),
            placement_literals.SERVER, 'value to be broadcasted')
        if not server_value.type_signature.all_equal:
            raise TypeError(
                'The broadcasted value should be equal at all locations.')

        broadcast = building_block_factory.create_federated_broadcast(
            value_impl.ValueImpl.get_comp(server_value))
        return value_impl.ValueImpl(self._bind_comp_as_reference(broadcast),
                                    self._context_stack)
示例#25
0
    def test_get_curried(self):
        """Currying `<int32,int32> -> int32` yields nested one-arg lambdas."""
        add_computation = computations.tf_computation(
            lambda a, b: tf.add(a, b),  # pylint: disable=unnecessary-lambda
            [tf.int32, tf.int32])
        add_proto = computation_impl.ComputationImpl.get_proto(add_computation)
        add_numbers = value_impl.ValueImpl(
            building_blocks.ComputationBuildingBlock.from_proto(add_proto),
            _context_stack)

        curried = value_utils.get_curried(add_numbers)
        self.assertEqual(str(curried.type_signature),
                         '(int32 -> (int32 -> int32))')

        # Compiled computation names are uniquified before comparing against
        # the expected `comp#1` name.
        uniquified, _ = tree_transformations.uniquify_compiled_computation_names(
            value_impl.ValueImpl.get_comp(curried))
        self.assertEqual(uniquified.compact_representation(),
                         '(arg0 -> (arg1 -> comp#1(<arg0,arg1>)))')
示例#26
0
    def federated_mean(self, value, weight):
        """Implements `federated_mean` as defined in `api/intrinsics.py`.

        Args:
          value: A value placed at `tff.CLIENTS` (or an implicitly zippable
            struct) whose type is average-compatible.
          weight: Either `None` (unweighted mean) or a CLIENTS-placed scalar
            tensor of an integer or floating-point dtype.

        Returns:
          A `value_impl.ValueImpl` bound to the federated mean computation.

        Raises:
          TypeError: If `value` is not (convertible to) a CLIENTS-placed value
            or is not average-compatible, or if `weight` is given but is not a
            CLIENTS-placed integer or floating-point scalar tensor.
        """
        # TODO(b/113112108): Possibly relax the constraints on numeric types, and
        # inject implicit casts where appropriate. For instance, we might want to
        # allow `tf.int32` values as the input, and automatically cast them to
        # `tf.float32` before invoking the average, thus producing a
        # floating-point result.

        # TODO(b/120439632): Possibly allow the weight to be either structured or
        # non-scalar, e.g., for the case of averaging a convolutional layer, when
        # we would want to use a different weight for every filter, and where it
        # might be cumbersome for users to have to manually slice and assemble a
        # variable.

        value = value_utils.ensure_federated_value(
            value_impl.to_value(value, None, self._context_stack),
            placement_literals.CLIENTS, 'value to be averaged')
        if not type_analysis.is_average_compatible(value.type_signature):
            raise TypeError(
                'The value type {} is not compatible with the average operator.'
                .format(value.type_signature))

        weight_comp = None
        if weight is not None:
            weight = value_utils.ensure_federated_value(
                value_impl.to_value(weight, None, self._context_stack),
                placement_literals.CLIENTS, 'weight to use in averaging')
            weight_type = weight.type_signature
            py_typecheck.check_type(weight_type.member,
                                    computation_types.TensorType)
            # The weight must be a rank-0 tensor ...
            if weight_type.member.shape.ndims != 0:
                raise TypeError(
                    'The weight type {} is not a federated scalar.'.format(
                        weight_type))
            # ... of an integer or floating-point dtype.
            weight_dtype = weight_type.member.dtype
            if not (weight_dtype.is_integer or weight_dtype.is_floating):
                raise TypeError(
                    'The weight type {} is not a federated integer or floating-point '
                    'tensor.'.format(weight_type))
            weight_comp = value_impl.ValueImpl.get_comp(weight)

        mean = building_block_factory.create_federated_mean(
            value_impl.ValueImpl.get_comp(value), weight_comp)
        return value_impl.ValueImpl(self._bind_comp_as_reference(mean),
                                    self._context_stack)
示例#27
0
 def test_value_impl_with_lambda(self):
   """Nested calls inside a Python function bind one symbol per call."""
   arg_type = [('f', computation_types.FunctionType(tf.int32, tf.int32)),
               ('x', tf.int32)]
   arg = value_impl.ValueImpl(
       building_blocks.Reference('arg', arg_type),
       context_stack_impl.context_stack)

   def apply_twice(a):
     return a.f(a.f(a.x))

   result_value = apply_twice(arg)
   self.assertIsInstance(result_value, value_base.Value)
   self.assertEqual(str(result_value.type_signature), 'int32')
   self.assertEqual(str(result_value), 'fc_FEDERATED_symbol_1')
   bound_symbols = context_stack_impl.context_stack.current.symbol_bindings
   self.assertLen(bound_symbols, 2)
   # The inner call is bound first, the outer call refers to it.
   expected_bindings = [
       ('fc_FEDERATED_symbol_0', 'arg.f(arg.x)'),
       ('fc_FEDERATED_symbol_1', 'arg.f(fc_FEDERATED_symbol_0)'),
   ]
   for binding, (expected_name, expected_comp) in zip(bound_symbols,
                                                      expected_bindings):
     self.assertEqual(binding[0], expected_name)
     self.assertEqual(str(binding[1]), expected_comp)
示例#28
0
def ensure_federated_value(value, placement=None, label=None):
    """Ensures `value` is a federated value placed at `placement`.

    If `value` is not a `computation_types.FederatedType` but is a
    `computation_types.StructType` that can be converted via `federated_zip`
    to a `computation_types.FederatedType`, inserts the call to
    `federated_zip` and returns the result. If `value` cannot be converted,
    raises a TypeError.

    Args:
      value: A `value_impl.ValueImpl` to check and convert to a federated value
        if possible.
      placement: The expected placement. If None, any placement is allowed.
      label: An optional string label that describes `value` in error messages.

    Returns:
      The value as a federated value, automatically zipping if necessary.

    Raises:
      TypeError: if `value` is not a `FederatedType` and cannot be converted to
        a `FederatedType` with `federated_zip`, or if `placement` is given and
        the value is placed elsewhere. When zipping fails, the underlying
        error is attached as `__cause__`.
    """
    py_typecheck.check_type(value, value_impl.ValueImpl)
    if label is not None:
        py_typecheck.check_type(label, str)
    description = label if label else 'value'

    if not value.type_signature.is_federated():
        comp = value_impl.ValueImpl.get_comp(value)
        context_stack = value_impl.ValueImpl.get_context_stack(value)
        try:
            zipped = building_block_factory.create_federated_zip(comp)
        except (TypeError, ValueError) as e:
            # Chain the zip failure so the reason the conversion failed is not
            # lost from the traceback.
            raise TypeError(
                'The {l} must be a FederatedType or implicitly convertible '
                'to a FederatedType (got a {t}).'.format(
                    l=description, t=comp.type_signature)) from e
        value = value_impl.ValueImpl(zipped, context_stack)

    if placement and value.type_signature.placement is not placement:
        raise TypeError(
            'The {} should be placed at {}, but it is placed at {}.'.format(
                description, placement, value.type_signature.placement))

    return value
示例#29
0
 def test_value_impl_with_tuple(self):
   """Struct values hand back their original child comps on every access."""
   foo_comp = building_blocks.Reference('foo', tf.int32)
   bar_comp = building_blocks.Reference('bar', tf.bool)
   z = value_impl.ValueImpl(
       building_blocks.Struct([foo_comp, ('y', bar_comp)]),
       context_stack_impl.context_stack)
   self.assertIsInstance(z, value_base.Value)
   self.assertEqual(str(z.type_signature), '<int32,y=bool>')
   self.assertEqual(str(z), '<foo,y=bar>')
   self.assertContainsSubset(['y'], dir(z))
   self.assertLen(z, 2)
   # Attribute, index and key access must all return the original comps.
   for element, expected_str, expected_comp in ((z.y, 'bar', bar_comp),
                                                (z[0], 'foo', foo_comp),
                                                (z['y'], 'bar', bar_comp)):
     self.assertEqual(str(element), expected_str)
     self.assertIs(value_impl.ValueImpl.get_comp(element), expected_comp)
   self.assertEqual(','.join(map(str, iter(z))), 'foo,bar')
   # A non-functional value cannot be called.
   with self.assertRaises(SyntaxError):
     z(10)
示例#30
0
    def federated_zip(self, value):
        """Implements `federated_zip` as defined in `api/intrinsics.py`.

        Args:
          value: Something `value_impl.to_value` converts into a value of a
            `StructType`; `create_federated_zip` builds the zip over its
            elements.

        Returns:
          A `value_impl.ValueImpl` bound to the zipped federated value.

        Raises:
          TypeError: If the converted value is not a `tff.Value` of a
            `StructType`.
        """
        # TODO(b/113112108): Extend this to accept *args.

        # TODO(b/113112108): We use the iterate/unwrap approach below because
        # our type system is not powerful enough to express the concept of
        # "an operation that takes tuples of T of arbitrary length", and therefore
        # the intrinsic federated_zip must only take a fixed number of arguments,
        # here fixed at 2. There are other potential approaches to getting around
        # this problem (e.g. having the operator act on sequences and thereby
        # sidestepping the issue) which we may want to explore.
        struct_value = value_impl.to_value(value, None, self._context_stack)
        py_typecheck.check_type(struct_value, value_base.Value)
        py_typecheck.check_type(struct_value.type_signature,
                                computation_types.StructType)

        zipped = building_block_factory.create_federated_zip(
            value_impl.ValueImpl.get_comp(struct_value))
        return value_impl.ValueImpl(self._bind_comp_as_reference(zipped),
                                    self._context_stack)