コード例 #1
0
ファイル: intrinsics_test.py プロジェクト: yf817/federated
    def test_sequence_sum(self):
        """Checks `tff.sequence_sum` type signatures for each placement.

        The same int32 sequence is summed unplaced, at SERVER and at CLIENTS,
        and the resulting computation signatures are compared as strings.
        """
        int_sequence = tff.SequenceType(tf.int32)

        @tff.federated_computation(int_sequence)
        def sum_unplaced(x):
            return tff.sequence_sum(x)

        self.assertEqual(str(sum_unplaced.type_signature), '(int32* -> int32)')

        @tff.federated_computation(tff.FederatedType(int_sequence, tff.SERVER))
        def sum_at_server(x):
            return tff.sequence_sum(x)

        expected_server_signature = '(int32*@SERVER -> int32@SERVER)'
        self.assertEqual(str(sum_at_server.type_signature),
                         expected_server_signature)

        @tff.federated_computation(tff.FederatedType(int_sequence, tff.CLIENTS))
        def sum_at_clients(x):
            return tff.sequence_sum(x)

        expected_clients_signature = '({int32*}@CLIENTS -> {int32}@CLIENTS)'
        self.assertEqual(str(sum_at_clients.type_signature),
                         expected_clients_signature)
コード例 #2
0
    def test_run_encoded_sum(self):
        """Sums two client values through an identity-encoded gather process.

        Each case supplies two client values and the expected sum, expressed as
        a multiple of the base `value` array.
        """
        value = np.array([0.0, 1.0, 2.0, -1.0])
        value_spec = tf.TensorSpec(value.shape, tf.as_dtype(value.dtype))
        value_type = tff.to_type(value_spec)
        encoder = te.encoders.as_gather_encoder(te.encoders.identity(),
                                                value_spec)
        gather_fn = encoding_utils.build_encoded_sum(value, encoder)
        initial_state = gather_fn.initialize()
        # The server-state type is read from the (private) initialize fn.
        state_type = gather_fn._initialize_fn.type_signature.result  # pylint: disable=protected-access

        @tff.federated_computation(
            tff.FederatedType(state_type, tff.SERVER),
            tff.FederatedType(value_type, tff.CLIENTS))
        def call_gather(state, value):
            return gather_fn(state, value)

        # (client values, multiplier such that the expected sum is mult * value)
        cases = [
            ([value, value], 2),
            ([value, -value], 0),
            ([value, 2 * value], 3),
        ]
        for client_values, multiplier in cases:
            _, value_sum = call_gather(initial_state, client_values)
            self.assertAllClose(multiplier * value, value_sum)
コード例 #3
0
    def test_call_returned_directly_creates_canonical_form(self):
        """Checks that a `next` which only returns a call can be compiled.

        `nested_next_fn` does nothing but return the result of calling
        `next_fn`; the resulting iterative process should still be convertible
        to a `canonical_form.CanonicalForm`.
        """
        # Defined once (through the public `tff.SequenceType` API) and shared
        # by `next_fn` and `nested_next_fn`, so the two parameter signatures
        # cannot drift apart.
        server_state_type = tff.FederatedType(tf.int32, tff.SERVER)
        client_data_type = tff.FederatedType(tff.SequenceType(tf.float32),
                                             tff.CLIENTS)

        @tff.federated_computation
        def init_fn():
            return tff.federated_value(42, tff.SERVER)

        @tff.federated_computation(server_state_type, client_data_type)
        def next_fn(server_state, client_data):
            broadcast_state = tff.federated_broadcast(server_state)

            @tff.tf_computation(tf.int32, tff.SequenceType(tf.float32))
            @tf.function
            def some_transform(x, y):
                del y  # Unused
                return x + 1

            client_update = tff.federated_map(some_transform,
                                              (broadcast_state, client_data))
            aggregate_update = tff.federated_sum(client_update)
            server_output = tff.federated_value(1234, tff.SERVER)
            return aggregate_update, server_output

        @tff.federated_computation(server_state_type, client_data_type)
        def nested_next_fn(server_state, client_data):
            # Returning the call directly is the pattern under test.
            return next_fn(server_state, client_data)

        iterative_process = computation_utils.IterativeProcess(
            init_fn, nested_next_fn)
        cf = canonical_form_utils.get_canonical_form_for_iterative_process(
            iterative_process)
        self.assertIsInstance(cf, canonical_form.CanonicalForm)
コード例 #4
0
  def test_federated_max_on_nested_scalars(self):
    """Element-wise `federated_max` over a named tuple of int32 scalars."""
    nested_scalars = collections.namedtuple('NestedScalars', ['a', 'b'])
    tuple_type = tff.NamedTupleType([('a', tf.int32), ('b', tf.int32)])

    @tff.federated_computation(tff.FederatedType(tuple_type, tff.CLIENTS))
    def call_federated_max(value):
      return federated_aggregations.federated_max(value)

    client_values = [
        nested_scalars(1, 5),
        nested_scalars(2, 3),
        nested_scalars(1, 8),
    ]
    result = call_federated_max(client_values)
    self.assertDictEqual(result._asdict(), {'a': 2, 'b': 8})
コード例 #5
0
ファイル: intrinsics_test.py プロジェクト: zmHe/federated
    def test_federated_zip_with_twenty_elements_local_executor(self):
        """Zipping 20 CLIENTS-placed int32 values must complete.

        Regression guard: the local executor used to scale factorially with
        the number of zipped elements, so this call would never have returned.
        """
        num_elements = 20
        num_clients = 2
        element_types = [tff.FederatedType(tf.int32, tff.CLIENTS)] * num_elements

        @tff.federated_computation(element_types)
        def zip_all(x):
            return tff.federated_zip(x)

        # One independent per-client list for each zipped element.
        client_data = [list(range(num_clients)) for _ in range(num_elements)]
        zip_all(client_data)
コード例 #6
0
ファイル: intrinsics_test.py プロジェクト: zmHe/federated
    def test_federated_apply_raises_warning(self):
        """`tff.federated_apply` emits a `DeprecationWarning` but still traces.

        The warning is recorded while the federated computation is traced, and
        the resulting type signature is still checked.
        """
        with warnings.catch_warnings(record=True) as w:
            # Record every warning, even ones emitted earlier in the process.
            warnings.simplefilter('always')

            @tff.federated_computation(tff.FederatedType(tf.int32, tff.SERVER))
            def foo(x):
                return tff.federated_apply(
                    tff.tf_computation(lambda x: x * x, tf.int32), x)

            self.assertLen(w, 1)
            # `WarningMessage.category` is the warning class itself; check it
            # with `issubclass` instead of instantiating it.
            self.assertTrue(issubclass(w[0].category, DeprecationWarning))
            self.assertIn('tff.federated_apply() is deprecated',
                          str(w[0].message))
            self.assertEqual(str(foo.type_signature),
                             '(int32@SERVER -> int32@SERVER)')
コード例 #7
0
  def test_sequence_reduce(self):
    """Checks `tff.sequence_reduce` type signatures for each placement."""
    add_numbers = tff.tf_computation(tf.add, [tf.int32, tf.int32])
    int_sequence = tff.SequenceType(tf.int32)

    @tff.federated_computation(int_sequence)
    def reduce_unplaced(x):
      return tff.sequence_reduce(x, 0, add_numbers)

    self.assertEqual(str(reduce_unplaced.type_signature), '(int32* -> int32)')

    @tff.federated_computation(tff.FederatedType(int_sequence, tff.SERVER))
    def reduce_at_server(x):
      return tff.sequence_reduce(x, 0, add_numbers)

    server_signature = str(reduce_at_server.type_signature)
    self.assertEqual(server_signature, '(int32*@SERVER -> int32@SERVER)')

    @tff.federated_computation(tff.FederatedType(int_sequence, tff.CLIENTS))
    def reduce_at_clients(x):
      return tff.sequence_reduce(x, 0, add_numbers)

    clients_signature = str(reduce_at_clients.type_signature)
    self.assertEqual(clients_signature,
                     '({int32*}@CLIENTS -> {int32}@CLIENTS)')
コード例 #8
0
  def test_federated_min_on_nested_scalars(self):
    """Element-wise `federated_min` over a named tuple of float32 scalars."""
    nested_scalars = collections.namedtuple('NestedScalars', ['x', 'y'])
    tuple_type = tff.NamedTupleType([('x', tf.float32), ('y', tf.float32)])

    @tff.federated_computation(tff.FederatedType(tuple_type, tff.CLIENTS))
    def call_federated_min(value):
      return federated_aggregations.federated_min(value)

    client_values = [
        nested_scalars(0.0, 1.0),
        nested_scalars(-1.0, 5.0),
        nested_scalars(2.0, -10.0),
    ]
    result = call_federated_min(client_values)
    self.assertDictEqual(result._asdict(), {'x': -1.0, 'y': -10.0})
コード例 #9
0
  def test_federated_max_nested_tensor_value(self):
    """Element-wise `federated_max` over a named tuple of int32 vectors."""
    tuple_type = tff.NamedTupleType([
        ('a', (tf.int32, [2])),
        ('b', (tf.int32, [3])),
    ])

    @tff.federated_computation(tff.FederatedType(tuple_type, tff.CLIENTS))
    def call_federated_max(value):
      return federated_aggregations.federated_max(value)

    test_type = collections.namedtuple('NestedScalars', ['a', 'b'])
    client1 = test_type(
        np.array([4, 5], dtype=np.int32), np.array([1, -2, 3], dtype=np.int32))
    client2 = test_type(
        np.array([9, 0], dtype=np.int32), np.array([5, 1, -2], dtype=np.int32))
    value = call_federated_max([client1, client2])
    # The maximum is taken position by position, so element order matters;
    # an order-insensitive comparison would also accept a permuted result.
    np.testing.assert_array_equal(value[0], [9, 5])
    np.testing.assert_array_equal(value[1], [5, 1, 3])
コード例 #10
0
def get_iterative_process_for_canonical_form(cf):
    """Creates `tff.utils.IterativeProcess` from a canonical form.

    Args:
      cf: An instance of `tff.backends.mapreduce.CanonicalForm`.

    Returns:
      An instance of `tff.utils.IterativeProcess` that corresponds to `cf`.

    Raises:
      TypeError: If the arguments are of the wrong types.
    """
    py_typecheck.check_type(cf, canonical_form.CanonicalForm)

    @tff.federated_computation
    def init_computation():
        return tff.federated_value(cf.initialize(), tff.SERVER)

    # `next` takes the server state produced by `init_computation` and the
    # CLIENTS-placed input expected by the first component of `cf.work`.
    @tff.federated_computation(init_computation.type_signature.result,
                               tff.FederatedType(
                                   cf.work.type_signature.parameter[0],
                                   tff.CLIENTS))
    def next_computation(arg):
        """The logic of a single MapReduce processing round."""
        # Names follow the `s*` (SERVER) / `c*` (CLIENTS) numbering used by
        # `canonical_form.CanonicalForm`.
        s1 = arg[0]
        c1 = arg[1]
        # NOTE(review): newer TFF releases deprecate `tff.federated_apply` in
        # favor of `tff.federated_map` on SERVER-placed values; kept here for
        # compatibility with the TFF version this module targets.
        s2 = tff.federated_apply(cf.prepare, s1)
        c2 = tff.federated_broadcast(s2)
        c3 = tff.federated_zip([c1, c2])
        c4 = tff.federated_map(cf.work, c3)
        # `work` yields a pair: the part to aggregate and the part returned
        # per-client without aggregation.
        c5 = c4[0]
        c6 = c4[1]
        s3 = tff.federated_aggregate(c5, cf.zero(), cf.accumulate, cf.merge,
                                     cf.report)
        s4 = tff.federated_zip([s1, s3])
        s5 = tff.federated_apply(cf.update, s4)
        s6 = s5[0]
        s7 = s5[1]
        return s6, s7, c6

    return computation_utils.IterativeProcess(init_computation,
                                              next_computation)
コード例 #11
0
  def test_federated_broadcast_with_client_int(self):
    """Broadcasting a CLIENTS-placed value is expected to raise TypeError."""
    with self.assertRaises(TypeError):
      client_int_type = tff.FederatedType(tf.int32, tff.CLIENTS, all_equal=True)

      @tff.federated_computation(client_int_type)
      def _(x):
        return tff.federated_broadcast(x)
コード例 #12
0
def check_and_pack_before_aggregate_type_signature(type_spec,
                                                   previously_packed_types):
    """Checks types inferred from `before_aggregate` and packs in `previously_packed_types`.

    After splitting the `after_broadcast` portion of a
    `tff.utils.IterativeProcess` into `before_aggregate` and `after_aggregate`,
    `before_aggregate` should have type signature
    `<<s1,c1>,c2> -> <c5,zero,accumulate,merge,report>`. This function
    validates `s1`, `c1` and `c2` against the existing entries in
    `previously_packed_types`, checks that the result components are mutually
    consistent, then packs the types of `c2`, `c3`, `c5`, `zero`,
    `accumulate`, `merge`, `report`, `s3`, `c4` and `work`.

    Args:
      type_spec: The `type_signature` attribute of the `before_aggregate`
        portion of the `tff.utils.IterativeProcess` from which we are looking
        to extract an instance of `canonical_form.CanonicalForm`.
      previously_packed_types: Dict containing the information from `next` and
        `before_broadcast` in the iterative process we are parsing. Must
        contain the keys `s1_type`, `c1_type`, `s2_type` and `c6_type`.

    Returns:
      A new `dict` containing every entry of `previously_packed_types` plus the
      types which can be inferred from `type_spec` (new entries win on key
      collisions).

    Raises:
      TypeError: If `type_spec` does not have the expected structure or is
        incompatible with `previously_packed_types`.
    """

    def _has_expected_structure():
        """Returns whether `type_spec` matches the documented signature.

        Every check short-circuits, so a malformed `type_spec` yields `False`
        (and therefore the descriptive `TypeError` below) rather than an
        incidental `AttributeError`/`IndexError` from indexing into it.
        """
        if not (isinstance(type_spec, tff.FunctionType)
                and isinstance(type_spec.parameter, tff.NamedTupleType)
                and len(type_spec.parameter) == 2):
            return False
        server_and_client_state = type_spec.parameter[0]
        if not (isinstance(server_and_client_state, tff.NamedTupleType)
                and len(server_and_client_state) == 2
                and server_and_client_state[0]
                == previously_packed_types['s1_type']
                and server_and_client_state[1]
                == previously_packed_types['c1_type']):
            return False
        c2 = type_spec.parameter[1]
        if not (isinstance(c2, tff.FederatedType)
                and c2.placement == tff.CLIENTS
                and c2.member == previously_packed_types['s2_type'].member):
            return False
        result = type_spec.result
        if not (isinstance(result, tff.NamedTupleType) and len(result) == 5):
            return False
        c5 = result[0]
        accumulator = result[1]
        accumulate = result[2]
        merge = result[3]
        report = result[4]
        return (isinstance(c5, tff.FederatedType)
                and c5.placement == tff.CLIENTS
                and tff_framework.is_tensorflow_compatible_type(accumulator)
                and accumulate == tff.FunctionType(
                    [accumulator, c5.member], accumulator)
                and merge == tff.FunctionType([accumulator, accumulator],
                                              accumulator)
                and isinstance(report, tff.FunctionType)
                and report.parameter == accumulator
                and tff_framework.is_tensorflow_compatible_type(
                    report.result))

    if not _has_expected_structure():
        # TODO(b/121290421): These error messages, and indeed the 'track boolean and
        # raise once' logic of these methods as well, is intended to be provisional
        # and revisited when we've seen the compilation pipeline fail more clearly,
        # or maybe preferably iteratively improved as new failure modes are
        # encountered.
        raise TypeError(
            'Encountered a type error while checking '
            '`before_aggregate`. Expected a type signature of the '
            'form `<<s1,c1>,c2> -> <c5,zero,accumulate,merge,report>`, '
            'where `s1` matches {}, `c1` matches {}, and `c2` matches '
            'the result of broadcasting {}, as defined in '
            '`canonical_form.CanonicalForm`. Found type signature {}.'.format(
                previously_packed_types['s1_type'],
                previously_packed_types['c1_type'],
                previously_packed_types['s2_type'], type_spec))

    c2_type = type_spec.parameter[1]
    c3_type = tff.FederatedType(
        [previously_packed_types['c1_type'].member, c2_type.member],
        tff.CLIENTS)
    c5_type = type_spec.result[0]
    report_type = type_spec.result[4]
    # `work` maps the zipped client input `c3` to the pair `<c5, c6>`.
    c4_type = tff.FederatedType(
        [c5_type.member, previously_packed_types['c6_type'].member],
        tff.CLIENTS)

    newly_determined_types = {}
    newly_determined_types['c2_type'] = c2_type
    newly_determined_types['c3_type'] = c3_type
    newly_determined_types['c5_type'] = c5_type
    # `zero` is a no-argument function producing the accumulator type.
    newly_determined_types['zero_type'] = tff.FunctionType(
        None, type_spec.result[1])
    newly_determined_types['accumulate_type'] = type_spec.result[2]
    newly_determined_types['merge_type'] = type_spec.result[3]
    newly_determined_types['report_type'] = report_type
    newly_determined_types['s3_type'] = tff.FederatedType(
        report_type.result, tff.SERVER)
    newly_determined_types['c4_type'] = c4_type
    newly_determined_types['work_type'] = tff.FunctionType(
        c3_type.member, c4_type.member)

    packed_types = dict(previously_packed_types)
    packed_types.update(newly_determined_types)
    return packed_types
コード例 #13
0
def check_and_pack_after_aggregate_type_signature(type_spec,
                                                  previously_packed_types):
    """Checks types inferred from `after_aggregate` and packs in `previously_packed_types`.

    After splitting the `next` portion of a `tff.utils.IterativeProcess` all
    the way down, `after_aggregate` should have type signature
    `<<<s1,c1>,c2>,s3> -> <s6,s7,c6>` (a two-element result `<s6,s7>` without
    `c6` is also accepted). This function validates every element of the
    above against `previously_packed_types`, then packs the types of `s4`,
    `s5`, `update` and `c3`.

    Args:
      type_spec: The `type_signature` attribute of the `after_aggregate`
        portion of the `tff.utils.IterativeProcess` from which we are looking
        to extract an instance of `canonical_form.CanonicalForm`.
      previously_packed_types: Dict containing the information from `next`,
        `before_broadcast` and `before_aggregate` in the iterative process we
        are parsing. Must contain the keys `s1_type`, `c1_type`, `c2_type`,
        `s3_type`, `s6_type`, `s7_type` and `c6_type`.

    Returns:
      A new `dict` containing every entry of `previously_packed_types` plus the
      types which can be inferred from `type_spec` (new entries win on key
      collisions).

    Raises:
      TypeError: If `type_spec` does not have the expected structure or is
        incompatible with `previously_packed_types`.
    """

    def _has_expected_structure():
        """Returns whether `type_spec` matches the documented signature.

        Every check short-circuits, so a malformed `type_spec` yields `False`
        (and therefore the descriptive `TypeError` below) rather than an
        incidental `AttributeError`/`IndexError` from indexing into it.
        """
        if not (isinstance(type_spec, tff.FunctionType)
                and isinstance(type_spec.parameter, tff.NamedTupleType)
                and len(type_spec.parameter) == 2
                and isinstance(type_spec.result, tff.NamedTupleType)):
            return False
        param = type_spec.parameter
        result = type_spec.result
        if not (isinstance(param[0], tff.NamedTupleType)
                and len(param[0]) == 2
                and isinstance(param[0][0], tff.NamedTupleType)
                and len(param[0][0]) == 2):
            return False
        if not (param[0][0][0] == previously_packed_types['s1_type']
                and param[0][0][1] == previously_packed_types['c1_type']
                and param[0][1] == previously_packed_types['c2_type']
                and param[1] == previously_packed_types['s3_type']):
            return False
        # The result is either `<s6,s7>` or `<s6,s7,c6>`; any other length is
        # malformed.
        if len(result) not in (2, 3):
            return False
        if not (result[0] == previously_packed_types['s6_type']
                and result[1] == previously_packed_types['s7_type']):
            return False
        return (len(result) == 2
                or result[2] == previously_packed_types['c6_type'])

    if not _has_expected_structure():
        # TODO(b/121290421): These error messages, and indeed the 'track boolean and
        # raise once' logic of these methods as well, is intended to be provisional
        # and revisited when we've seen the compilation pipeline fail more clearly,
        # or maybe preferably iteratively improved as new failure modes are
        # encountered.
        raise TypeError(
            'Encountered a type error while checking `after_aggregate`; '
            'expected a type signature of the form '
            '`<<<s1,c1>,c2>,s3> -> <s6,s7,c6>`, where s1 matches {}, '
            'c1 matches {}, c2 matches {}, s3 matches {}, s6 matches '
            '{}, s7 matches {}, c6 matches {}, as defined in '
            '`canonical_form.CanonicalForm`. Encountered a type signature '
            '{}.'.format(previously_packed_types['s1_type'],
                         previously_packed_types['c1_type'],
                         previously_packed_types['c2_type'],
                         previously_packed_types['s3_type'],
                         previously_packed_types['s6_type'],
                         previously_packed_types['s7_type'],
                         previously_packed_types['c6_type'], type_spec))

    # `update` consumes the zipped server state and aggregate <s1, s3> (= s4)
    # and produces the pair <s6, s7> (= s5).
    s4_type = tff.FederatedType([
        previously_packed_types['s1_type'].member,
        previously_packed_types['s3_type'].member
    ], tff.SERVER)
    s5_type = tff.FederatedType([
        previously_packed_types['s6_type'].member,
        previously_packed_types['s7_type'].member
    ], tff.SERVER)
    c3_type = tff.FederatedType([
        previously_packed_types['c1_type'].member,
        previously_packed_types['c2_type'].member
    ], tff.CLIENTS)

    newly_determined_types = {}
    newly_determined_types['s4_type'] = s4_type
    newly_determined_types['s5_type'] = s5_type
    newly_determined_types['update_type'] = tff.FunctionType(
        s4_type.member, s5_type.member)
    newly_determined_types['c3_type'] = c3_type

    packed_types = dict(previously_packed_types)
    packed_types.update(newly_determined_types)
    return packed_types
コード例 #14
0
def zip_selection_as_argument_to_lower_level_lambda(comp, selected_index_lists):
  r"""Binds selections from the param of `comp` as params to lower-level lambda.

  Notice that `comp` must be a `tff_framework.Lambda`.

  The returned pattern is quite important here; given an input lambda `Comp`,
  we will return an equivalent structure of the form:


                                    Lambda(x)
                                       |
                                      Call
                                    /      \
                              Lambda        <Selections from x>

  Where <Selections from x> represents a tuple of selections from the parameter
  `x`, as specified by `selected_index_lists`. This transform is necessary in
  order to isolate spurious dependence on arguments that are not in fact used,
  for example after we have separated processing on the server from that which
  happens on the clients, but the server-processing still declares some
  parameters placed at the clients.

  `selected_index_lists` must be a non-empty list of lists. Each list
  represents a sequence of selections to the parameter of `comp`. For example,
  if `x` is the parameter of `comp`, the list `[0, 1, 0]` would represent the
  selection `x[0][1][0]`. The elements of these inner lists must be integers;
  that is, the selections must be positional. Notice we do not allow for tuples
  due to automatic unwrapping.

  Every selected value must be of federated type, and all of them must share
  the same placement, because they are combined with a federated zip.

  Args:
    comp: Instance of `tff_framework.Lambda`, whose parameters we wish to rebind
      to a different lambda.
    selected_index_lists: Non-empty 2-d list of `int`s, specifying the
      parameters of `comp` which we wish to rebind as the parameter to a
      lower-level lambda.

  Returns:
    An instance of `tff_framework.Lambda`, equivalent to `comp`, satisfying the
    pattern above.

  Raises:
    TypeError: If `comp` or `selected_index_lists` have the wrong Python types,
      if a selection does not exist in the parameter type of `comp`, or if a
      selected value is not of federated type.
    ValueError: If `selected_index_lists` is empty, or if the selected values
      are not all at the same placement.
  """
  py_typecheck.check_type(comp, tff_framework.Lambda)
  py_typecheck.check_type(selected_index_lists, list)
  for selection_list in selected_index_lists:
    py_typecheck.check_type(selection_list, list)
    for selected_element in selection_list:
      py_typecheck.check_type(selected_element, int)
  if not selected_index_lists:
    # Without this check, the placement lookup below would fail with an
    # unexplained IndexError.
    raise ValueError(
        'At least one selection must be provided in `selected_index_lists`.')
  original_comp = comp
  comp = _prepare_for_rebinding(comp)

  top_level_parameter_type = comp.type_signature.parameter
  name_generator = tff_framework.unique_name_generator(comp)
  top_level_parameter_name = comp.parameter_name
  top_level_parameter_reference = tff_framework.Reference(
      top_level_parameter_name, comp.parameter_type)

  type_list = []
  for selection_list in selected_index_lists:
    try:
      selected_type = top_level_parameter_type
      for selection in selection_list:
        selected_type = selected_type[selection]
      type_list.append(selected_type)
    # TypeError: indexing into a non-tuple type; IndexError: positional index
    # out of range for a tuple type. Both mean the selection does not exist.
    except (TypeError, IndexError):
      six.reraise(
          TypeError,
          TypeError(
              'You have tried to bind a variable to a nonexistent index in your '
              'lambda parameter type; the selection defined by {} is '
              'inadmissible for the lambda parameter type {}, in the comp {}.'
              .format(selection_list, top_level_parameter_type, original_comp)),
          sys.exc_info()[2])

  if not all(isinstance(x, tff.FederatedType) for x in type_list):
    raise TypeError(
        'All selected arguments should be of federated type; your selections '
        'have resulted in the list of types {}'.format(type_list))
  placement = type_list[0].placement
  if not all(x.placement is placement for x in type_list):
    raise ValueError(
        'In order to zip the argument to the lower-level lambda together, all '
        'selected arguments should be at the same placement. Your selections '
        'have resulted in the list of types {}'.format(type_list))

  arg_to_lower_level_lambda_list = []
  for selection_tuple in selected_index_lists:
    selected_comp = top_level_parameter_reference
    for selection in selection_tuple:
      selected_comp = tff_framework.Selection(selected_comp, index=selection)
    arg_to_lower_level_lambda_list.append(selected_comp)
  zip_arg = tff_framework.create_federated_zip(
      tff_framework.Tuple(arg_to_lower_level_lambda_list))

  zip_type = tff.FederatedType([x.member for x in type_list],
                               placement=placement)
  ref_to_zip = tff_framework.Reference(six.next(name_generator), zip_type)

  selections_from_zip = [
      _construct_selection_from_federated_tuple(ref_to_zip, x, name_generator)
      for x in range(len(selected_index_lists))
  ]

  def _replace_selections_with_new_bindings(inner_comp):
    """Identifies selection pattern and replaces with new binding.

    Detecting this pattern is the most brittle part of this rebinding function.
    It relies on pattern-matching, and right now we cannot guarantee that this
    pattern is present in every situation we wish to replace with a new
    binding.

    Args:
      inner_comp: Instance of `tff_framework.ComputationBuildingBlock` in which
        we wish to replace the selections specified by `selected_index_lists`
        with the parallel new bindings from `selections_from_zip`.

    Returns:
      A possibly transformed version of `inner_comp` with nodes matching the
      selection patterns replaced by their new bindings.
    """
    # TODO(b/135541729): Either come up with a preprocessing way to enforce
    # this is sufficient, or rework the should_transform predicate.
    for idx, tup in enumerate(selected_index_lists):
      selection = inner_comp  # Empty selection
      tuple_pattern_matched = True
      # The outermost Selection node holds the last index, so peel the
      # selection chain while walking the index list in reverse.
      for selected_index in tup[::-1]:
        if isinstance(
            selection,
            tff_framework.Selection) and selection.index == selected_index:
          selection = selection.source
        else:
          tuple_pattern_matched = False
          break
      if tuple_pattern_matched:
        if isinstance(selection, tff_framework.Reference
                     ) and selection.name == top_level_parameter_name:
          return selections_from_zip[idx], True
    return inner_comp, False

  variables_rebound_in_result, _ = tff_framework.transform_postorder(
      comp.result, _replace_selections_with_new_bindings)
  lambda_with_zipped_param = tff_framework.Lambda(ref_to_zip.name,
                                                  ref_to_zip.type_signature,
                                                  variables_rebound_in_result)
  _check_for_missed_binding(comp, lambda_with_zipped_param)

  zipped_lambda_called = tff_framework.Call(lambda_with_zipped_param, zip_arg)
  constructed_lambda = tff_framework.Lambda(comp.parameter_name,
                                            comp.parameter_type,
                                            zipped_lambda_called)
  names_uniquified, _ = tff_framework.uniquify_reference_names(
      constructed_lambda)
  return names_uniquified
コード例 #15
0
ファイル: intrinsics_test.py プロジェクト: zmHe/federated
class IntrinsicsTest(parameterized.TestCase):
    def test_federated_broadcast_with_server_all_equal_int(self):
        """Checks the type signature of broadcasting a SERVER int32."""
        server_int_type = tff.FederatedType(tf.int32, tff.SERVER)

        @tff.federated_computation(server_int_type)
        def broadcast_fn(x):
            return tff.federated_broadcast(x)

        expected_signature = '(int32@SERVER -> int32@CLIENTS)'
        self.assertEqual(str(broadcast_fn.type_signature), expected_signature)

    def test_federated_broadcast_with_server_non_all_equal_int(self):
        """Broadcasting a non-all-equal SERVER value is expected to fail."""
        with self.assertRaises(TypeError):
            non_all_equal_type = tff.FederatedType(
                tf.int32, tff.SERVER, all_equal=False)

            @tff.federated_computation(non_all_equal_type)
            def _(x):
                return tff.federated_broadcast(x)

    def test_federated_broadcast_with_client_int(self):
        """Broadcasting a CLIENTS-placed value is expected to fail."""
        with self.assertRaises(TypeError):
            client_int_type = tff.FederatedType(
                tf.int32, tff.CLIENTS, all_equal=True)

            @tff.federated_computation(client_int_type)
            def _(x):
                return tff.federated_broadcast(x)

    def test_federated_broadcast_with_non_federated_val(self):
        """Broadcasting an unplaced int32 is expected to fail."""
        with self.assertRaises(TypeError):
            unplaced_type = tf.int32

            @tff.federated_computation(unplaced_type)
            def _(x):
                return tff.federated_broadcast(x)

    def test_federated_map_with_client_all_equal_int(self):
        """Mapping an all-equal CLIENTS int32 gives a non-all-equal bool."""
        client_int_type = tff.FederatedType(
            tf.int32, tff.CLIENTS, all_equal=True)

        @tff.federated_computation(client_int_type)
        def map_fn(x):
            greater_than_ten = tff.tf_computation(lambda x: x > 10, tf.int32)
            return tff.federated_map(greater_than_ten, x)

        expected_signature = '(int32@CLIENTS -> {bool}@CLIENTS)'
        self.assertEqual(str(map_fn.type_signature), expected_signature)

    def test_federated_map_with_client_non_all_equal_int(self):
        """Mapping a non-all-equal CLIENTS int32 gives a CLIENTS bool."""
        client_int_type = tff.FederatedType(tf.int32, tff.CLIENTS)

        @tff.federated_computation(client_int_type)
        def map_fn(x):
            greater_than_ten = tff.tf_computation(lambda x: x > 10, tf.int32)
            return tff.federated_map(greater_than_ten, x)

        expected_signature = '({int32}@CLIENTS -> {bool}@CLIENTS)'
        self.assertEqual(str(map_fn.type_signature), expected_signature)

    def test_federated_map_with_server_int(self):
        """Mapping a SERVER int32 gives a SERVER bool."""
        server_int_type = tff.FederatedType(tf.int32, tff.SERVER)

        @tff.federated_computation(server_int_type)
        def map_fn(x):
            greater_than_ten = tff.tf_computation(lambda x: x > 10, tf.int32)
            return tff.federated_map(greater_than_ten, x)

        expected_signature = '(int32@SERVER -> bool@SERVER)'
        self.assertEqual(str(map_fn.type_signature), expected_signature)

    def test_federated_map_injected_zip_with_server_int(self):
        """`federated_map` over a list of SERVER values zips them implicitly."""
        server_int_type = tff.FederatedType(tf.int32, tff.SERVER)

        @tff.federated_computation([server_int_type, server_int_type])
        def map_fn(x, y):
            first_greater_than_ten = tff.tf_computation(
                lambda x, y: x > 10, [tf.int32, tf.int32])
            return tff.federated_map(first_greater_than_ten, [x, y])

        expected_signature = '(<int32@SERVER,int32@SERVER> -> bool@SERVER)'
        self.assertEqual(str(map_fn.type_signature), expected_signature)

    def test_federated_map_injected_zip_fails_different_placements(self):
        """Implicit zipping rejects arguments with mixed placements."""

        def map_fn(x, y):
            first_greater_than_ten = tff.tf_computation(
                lambda x, y: x > 10, [tf.int32, tf.int32])
            return tff.federated_map(first_greater_than_ten, [x, y])

        mixed_placement_types = [
            tff.FederatedType(tf.int32, tff.SERVER),
            tff.FederatedType(tf.int32, tff.CLIENTS),
        ]
        expected_message = (
            'You cannot apply federated_map on nested values with mixed '
            'placements.')
        with self.assertRaisesRegex(TypeError, expected_message):
            tff.federated_computation(map_fn, mixed_placement_types)

    def test_federated_map_with_non_federated_val(self):
        """`federated_map` over an unplaced int32 is expected to fail."""
        with self.assertRaises(TypeError):

            @tff.federated_computation(tf.int32)
            def _(x):
                greater_than_ten = tff.tf_computation(lambda x: x > 10,
                                                      tf.int32)
                return tff.federated_map(greater_than_ten, x)

    def test_federated_sum_with_client_int(self):
        """Summing CLIENTS int32 values produces a SERVER int32."""
        client_int_type = tff.FederatedType(tf.int32, tff.CLIENTS)

        @tff.federated_computation(client_int_type)
        def sum_fn(x):
            return tff.federated_sum(x)

        expected_signature = '({int32}@CLIENTS -> int32@SERVER)'
        self.assertEqual(str(sum_fn.type_signature), expected_signature)

    def test_federated_sum_with_client_string(self):
        """`federated_sum` over CLIENTS strings is expected to fail."""
        with self.assertRaises(TypeError):
            client_string_type = tff.FederatedType(tf.string, tff.CLIENTS)

            @tff.federated_computation(client_string_type)
            def _(x):
                return tff.federated_sum(x)

    def test_federated_sum_with_server_int(self):
        """`federated_sum` over a SERVER-placed value is expected to fail."""
        with self.assertRaises(TypeError):
            server_int_type = tff.FederatedType(tf.int32, tff.SERVER)

            @tff.federated_computation(server_int_type)
            def _(x):
                return tff.federated_sum(x)

    def test_federated_zip_with_client_non_all_equal_int_and_bool(self):
        """Zipping a non-all-equal and an all-equal CLIENTS value."""
        arg_types = [
            tff.FederatedType(tf.int32, tff.CLIENTS),
            tff.FederatedType(tf.bool, tff.CLIENTS, all_equal=True),
        ]

        @tff.federated_computation(arg_types)
        def zip_fn(x, y):
            return tff.federated_zip([x, y])

        expected_signature = (
            '(<{int32}@CLIENTS,bool@CLIENTS> -> {<int32,bool>}@CLIENTS)')
        self.assertEqual(str(zip_fn.type_signature), expected_signature)

    def test_federated_zip_with_single_unnamed_int_client(self):
        """Zipping a one-element unnamed tuple of CLIENTS int32 values."""
        arg_types = [tff.FederatedType(tf.int32, tff.CLIENTS)]

        @tff.federated_computation(arg_types)
        def zip_fn(x):
            return tff.federated_zip(x)

        expected_signature = '(<{int32}@CLIENTS> -> {<int32>}@CLIENTS)'
        self.assertEqual(str(zip_fn.type_signature), expected_signature)

    def test_federated_zip_with_single_unnamed_int_server(self):
        """Zipping a one-element unnamed tuple of SERVER int32 values."""
        arg_types = [tff.FederatedType(tf.int32, tff.SERVER)]

        @tff.federated_computation(arg_types)
        def zip_fn(x):
            return tff.federated_zip(x)

        expected_signature = '(<int32@SERVER> -> <int32>@SERVER)'
        self.assertEqual(str(zip_fn.type_signature), expected_signature)

    def test_federated_zip_with_single_named_bool_clients(self):
        """Zipping a one-element named tuple keeps the name at CLIENTS."""
        arg_types = [('a', tff.FederatedType(tf.bool, tff.CLIENTS))]

        @tff.federated_computation(arg_types)
        def zip_fn(x):
            return tff.federated_zip(x)

        expected_signature = '(<a={bool}@CLIENTS> -> {<a=bool>}@CLIENTS)'
        self.assertEqual(str(zip_fn.type_signature), expected_signature)

    def test_federated_zip_with_single_named_bool_server(self):
        """Zipping a one-element named tuple keeps the name at SERVER."""
        arg_types = [('a', tff.FederatedType(tf.bool, tff.SERVER))]

        @tff.federated_computation(arg_types)
        def zip_fn(x):
            return tff.federated_zip(x)

        expected_signature = '(<a=bool@SERVER> -> <a=bool>@SERVER)'
        self.assertEqual(str(zip_fn.type_signature), expected_signature)

    def test_federated_zip_with_names_client_non_all_equal_int_and_bool(self):
        """Zipping a dict of CLIENTS values yields a named CLIENTS tuple."""
        arg_types = [
            tff.FederatedType(tf.int32, tff.CLIENTS),
            tff.FederatedType(tf.bool, tff.CLIENTS, all_equal=True),
        ]

        @tff.federated_computation(arg_types)
        def zip_fn(x, y):
            return tff.federated_zip({'x': x, 'y': y})

        expected_signature = (
            '(<{int32}@CLIENTS,bool@CLIENTS> -> {<x=int32,y=bool>}@CLIENTS)')
        self.assertEqual(str(zip_fn.type_signature), expected_signature)

    def test_federated_zip_with_client_all_equal_int_and_bool(self):
        """All-equal CLIENTS inputs zip into a non-all-equal CLIENTS tuple."""
        param_types = [
            tff.FederatedType(tf.int32, tff.CLIENTS, True),
            tff.FederatedType(tf.bool, tff.CLIENTS, True),
        ]

        @tff.federated_computation(param_types)
        def foo(x, y):
            zipped = tff.federated_zip([x, y])
            return zipped

        expected = '(<int32@CLIENTS,bool@CLIENTS> -> {<int32,bool>}@CLIENTS)'
        self.assertEqual(str(foo.type_signature), expected)

    def test_federated_zip_with_names_client_all_equal_int_and_bool(self):
        """Named zip of all-equal CLIENTS values taken by position from arg."""
        param_types = [
            tff.FederatedType(tf.int32, tff.CLIENTS, True),
            tff.FederatedType(tf.bool, tff.CLIENTS, True),
        ]

        @tff.federated_computation(param_types)
        def foo(arg):
            first, second = arg[0], arg[1]
            return tff.federated_zip({'x': first, 'y': second})

        expected = ('(<int32@CLIENTS,bool@CLIENTS> -> '
                    '{<x=int32,y=bool>}@CLIENTS)')
        self.assertEqual(str(foo.type_signature), expected)

    def test_federated_zip_with_server_int_and_bool(self):
        """Two SERVER values zip into a single SERVER tuple."""
        param_types = [
            tff.FederatedType(tf.int32, tff.SERVER),
            tff.FederatedType(tf.bool, tff.SERVER),
        ]

        @tff.federated_computation(param_types)
        def foo(x, y):
            zipped = tff.federated_zip([x, y])
            return zipped

        expected = '(<int32@SERVER,bool@SERVER> -> <int32,bool>@SERVER)'
        self.assertEqual(str(foo.type_signature), expected)

    def test_federated_zip_with_names_server_int_and_bool(self):
        """A named tuple of SERVER values zips with its names preserved."""
        param_types = [
            ('a', tff.FederatedType(tf.int32, tff.SERVER)),
            ('b', tff.FederatedType(tf.bool, tff.SERVER)),
        ]

        @tff.federated_computation(param_types)
        def foo(arg):
            return tff.federated_zip(arg)

        expected = ('(<a=int32@SERVER,b=bool@SERVER> -> '
                    '<a=int32,b=bool>@SERVER)')
        self.assertEqual(str(foo.type_signature), expected)

    def test_federated_zip_error_different_placements(self):
        """Zipping a SERVER value with a CLIENTS value raises TypeError."""
        expected_message = (r'The elements .* must be placed at SERVER. '
                            r'Element placements: \(SERVER,CLIENTS\)')
        with self.assertRaisesRegex(TypeError, expected_message):
            mixed_types = [
                ('a', tff.FederatedType(tf.int32, tff.SERVER)),
                ('b', tff.FederatedType(tf.bool, tff.CLIENTS)),
            ]

            @tff.federated_computation(mixed_types)
            def _(arg):
                return tff.federated_zip(arg)

    def test_federated_collect_with_client_int(self):
        """Collecting CLIENTS int32 values yields an int32 sequence on SERVER."""
        client_int = tff.FederatedType(tf.int32, tff.CLIENTS)

        @tff.federated_computation(client_int)
        def foo(x):
            return tff.federated_collect(x)

        expected = '({int32}@CLIENTS -> int32*@SERVER)'
        self.assertEqual(str(foo.type_signature), expected)

    def test_federated_collect_with_server_int_fails(self):
        """federated_collect rejects a SERVER-placed argument with TypeError."""
        with self.assertRaises(TypeError):
            server_int = tff.FederatedType(tf.int32, tff.SERVER)

            @tff.federated_computation(server_int)
            def _(x):
                return tff.federated_collect(x)

    def test_federated_mean_with_client_float32_without_weight(self):
        """Unweighted mean of CLIENTS float32 values lands on SERVER."""
        client_float = tff.FederatedType(tf.float32, tff.CLIENTS)

        @tff.federated_computation(client_float)
        def foo(x):
            return tff.federated_mean(x)

        expected = '({float32}@CLIENTS -> float32@SERVER)'
        self.assertEqual(str(foo.type_signature), expected)

    def test_federated_mean_with_all_equal_client_float32_without_weight(self):
        """Unweighted mean of an all-equal CLIENTS float32 lands on SERVER."""

        @tff.federated_computation(
            tff.FederatedType(tf.float32, tff.CLIENTS, all_equal=True))
        def foo(x):
            return tff.federated_mean(x)

        expected = '(float32@CLIENTS -> float32@SERVER)'
        self.assertEqual(str(foo.type_signature), expected)

    def test_federated_mean_with_all_equal_client_float32_with_weight(self):
        """Weighted mean of an all-equal CLIENTS float32 lands on SERVER."""

        @tff.federated_computation(
            tff.FederatedType(tf.float32, tff.CLIENTS, all_equal=True))
        def foo(x):
            # The value is also used as its own weight.
            return tff.federated_mean(x, x)

        expected = '(float32@CLIENTS -> float32@SERVER)'
        self.assertEqual(str(foo.type_signature), expected)

    def test_federated_mean_with_client_tuple_with_int32_weight(self):
        """Weighted mean of a named float64 pair with int32 CLIENTS weights."""
        value_type = tff.FederatedType([('x', tf.float64), ('y', tf.float64)],
                                       tff.CLIENTS)
        weight_type = tff.FederatedType(tf.int32, tff.CLIENTS)

        @tff.federated_computation([value_type, weight_type])
        def foo(x, y):
            return tff.federated_mean(x, y)

        expected = ('(<{<x=float64,y=float64>}@CLIENTS,{int32}@CLIENTS> '
                    '-> <x=float64,y=float64>@SERVER)')
        self.assertEqual(str(foo.type_signature), expected)

    def test_federated_mean_with_client_int32_fails(self):
        """federated_mean rejects CLIENTS int32 values with TypeError."""
        with self.assertRaises(TypeError):
            client_int = tff.FederatedType(tf.int32, tff.CLIENTS)

            @tff.federated_computation(client_int)
            def _(x):
                return tff.federated_mean(x)

    def test_federated_mean_with_string_weight_fails(self):
        """federated_mean rejects a string-typed weight with TypeError."""
        with self.assertRaises(TypeError):
            param_types = [
                tff.FederatedType(tf.float32, tff.CLIENTS),
                tff.FederatedType(tf.string, tff.CLIENTS),
            ]

            @tff.federated_computation(param_types)
            def _(x, y):
                return tff.federated_mean(x, y)

    def test_federated_aggregate_with_client_int(self):
        """Averages CLIENTS int32 values through a (total, count) accumulator."""
        # pylint: disable=invalid-name
        Accumulator = collections.namedtuple('Accumulator', ['total', 'count'])
        # pylint: enable=invalid-name
        accumulator_type = tff.NamedTupleType(Accumulator(tf.int32, tf.int32))

        # Stage 1: fold one client element into the running total and count.
        @tff.tf_computation([accumulator_type, tf.int32])
        def accumulate(accu, elem):
            return Accumulator(total=accu.total + elem, count=accu.count + 1)

        # Stage 2: combine two partial accumulators field by field.
        @tff.tf_computation([accumulator_type, accumulator_type])
        def merge(x, y):
            return Accumulator(total=x.total + y.total,
                               count=x.count + y.count)

        # Stage 3: report total / count as a float32 ratio.
        @tff.tf_computation(accumulator_type)
        def report(accu):
            numerator = tf.cast(accu.total, tf.float32)
            denominator = tf.cast(accu.count, tf.float32)
            return numerator / denominator

        @tff.federated_computation(tff.FederatedType(tf.int32, tff.CLIENTS))
        def foo(x):
            zero = Accumulator(total=0, count=0)
            return tff.federated_aggregate(x, zero, accumulate, merge, report)

        expected = '({int32}@CLIENTS -> float32@SERVER)'
        self.assertEqual(str(foo.type_signature), expected)

    def test_federated_aggregate_with_federated_zero_fails(self):
        """federated_aggregate rejects a SERVER-placed `zero` value."""

        @tff.federated_computation()
        def build_federated_zero():
            return tff.federated_value(0, tff.SERVER)

        # Accumulate: add one client element to the running int32 sum.
        @tff.tf_computation([tf.int32, tf.int32])
        def accumulate(accu, elem):
            return accu + elem

        # Merge: add two partial sums together.
        @tff.tf_computation([tf.int32, tf.int32])
        def merge(x, y):
            return x + y

        # Report: return the final sum unchanged.
        @tff.tf_computation(tf.int32)
        def report(accu):
            return accu

        def foo(x):
            # The zero passed here is placed at SERVER, which is what the
            # assertion below expects to be rejected.
            return tff.federated_aggregate(x, build_federated_zero(),
                                           accumulate, merge, report)

        expected_message = ('Expected `zero` to be assignable to type int32, '
                            'but was of incompatible type int32@SERVER')
        with self.assertRaisesRegex(TypeError, expected_message):
            tff.federated_computation(foo,
                                      tff.FederatedType(tf.int32, tff.CLIENTS))

    def test_federated_aggregate_with_unknown_dimension(self):
        """Aggregates into an accumulator whose length is not known statically."""
        Accumulator = collections.namedtuple('Accumulator', ['samples'])  # pylint: disable=invalid-name
        samples_type = tff.TensorType(dtype=tf.int32, shape=[None])
        accumulator_type = tff.NamedTupleType(
            Accumulator(samples=samples_type))

        @tff.tf_computation()
        def build_empty_accumulator():
            return Accumulator(samples=tf.zeros(shape=[0], dtype=tf.int32))

        # Accumulate: append the scalar element to the samples vector.
        @tff.tf_computation([accumulator_type, tf.int32])
        def accumulate(accu, elem):
            new_sample = tf.expand_dims(elem, axis=0)
            return Accumulator(
                samples=tf.concat([accu.samples, new_sample], axis=0))

        # Merge: concatenate the samples of two partial accumulators.
        @tff.tf_computation([accumulator_type, accumulator_type])
        def merge(x, y):
            joined = tf.concat([x.samples, y.samples], axis=0)
            return Accumulator(samples=joined)

        # Report: return the accumulator unchanged.
        @tff.tf_computation(accumulator_type)
        def report(accu):
            return accu

        @tff.federated_computation(tff.FederatedType(tf.int32, tff.CLIENTS))
        def foo(x):
            return tff.federated_aggregate(x, build_empty_accumulator(),
                                           accumulate, merge, report)

        expected = '({int32}@CLIENTS -> <samples=int32[?]>@SERVER)'
        self.assertEqual(str(foo.type_signature), expected)

    def test_federated_reduce_with_tf_add_raw_constant(self):
        """federated_reduce accepts a raw Python int as its initial value."""

        @tff.federated_computation(tff.FederatedType(tf.int32, tff.CLIENTS))
        def foo(x):
            return tff.federated_reduce(
                x, 0, tff.tf_computation(tf.add, [tf.int32, tf.int32]))

        expected = '({int32}@CLIENTS -> int32@SERVER)'
        self.assertEqual(str(foo.type_signature), expected)

    def test_num_over_temperature_threshold_example(self):
        """Counts clients whose temperature exceeds a SERVER threshold."""
        param_types = [
            tff.FederatedType(tf.float32, tff.CLIENTS),
            tff.FederatedType(tf.float32, tff.SERVER),
        ]

        @tff.federated_computation(param_types)
        def foo(temperatures, threshold):
            # 1 where the client's temperature is above the threshold, else 0.
            is_over = tff.tf_computation(
                lambda x, y: tf.cast(tf.greater(x, y), tf.int32),
                [tf.float32, tf.float32])
            client_threshold = tff.federated_broadcast(threshold)
            flags = tff.federated_map(is_over,
                                      [temperatures, client_threshold])
            return tff.federated_sum(flags)

        expected = '(<{float32}@CLIENTS,float32@SERVER> -> int32@SERVER)'
        self.assertEqual(str(foo.type_signature), expected)

    @parameterized.named_parameters(('test_n_2', 2), ('test_n_3', 3),
                                    ('test_n_5', 5))
    def test_n_tuple_federated_zip_tensor_args(self, n):
        """Zipping n CLIENTS int32 values gives a CLIENTS tuple of n int32s."""
        fed_type = tff.FederatedType(tf.int32, tff.CLIENTS)
        param_types = [fed_type] * n
        result_type = tff.FederatedType([tf.int32] * n, tff.CLIENTS)
        expected = str(
            tff.FunctionType(tff.NamedTupleType(param_types), result_type))

        @tff.federated_computation(param_types)
        def foo(x):
            return tff.federated_zip(x)

        self.assertEqual(str(foo.type_signature), expected)

    @parameterized.named_parameters(
        ('test_n_2_int', 2, tff.FederatedType(tf.int32, tff.CLIENTS)),
        ('test_n_3_int', 3, tff.FederatedType(tf.int32, tff.CLIENTS)),
        ('test_n_5_int', 5, tff.FederatedType(tf.int32, tff.CLIENTS)),
        ('test_n_2_tuple', 2,
         tff.FederatedType([tf.int32, tf.int32], tff.CLIENTS)),
        ('test_n_3_tuple', 3,
         tff.FederatedType([tf.int32, tf.int32], tff.CLIENTS)),
        ('test_n_5_tuple', 5,
         tff.FederatedType([tf.int32, tf.int32], tff.CLIENTS)))
    def test_named_n_tuple_federated_zip(self, n, fed_type):
        """Zip args named fully (dict) or every other one (AnonymousTuple)."""
        member = fed_type.member
        param_type = tff.NamedTupleType([fed_type] * n)
        all_named = tff.FederatedType([(str(k), member) for k in range(n)],
                                      tff.CLIENTS)
        mixed_named = tff.FederatedType(
            [(str(k), member) if k % 2 == 0 else member for k in range(n)],
            tff.CLIENTS)
        named_type_string = str(tff.FunctionType(param_type, all_named))
        mixed_type_string = str(tff.FunctionType(param_type, mixed_named))

        @tff.federated_computation([fed_type] * n)
        def foo(x):
            return tff.federated_zip({str(k): x[k] for k in range(n)})

        self.assertEqual(str(foo.type_signature), named_type_string)

        @tff.federated_computation([fed_type] * n)
        def bar(x):
            # Even positions get their index as a name; odd ones stay unnamed.
            elements = []
            for k in range(n):
                name = str(k) if k % 2 == 0 else None
                elements.append((name, x[k]))
            return tff.federated_zip(anonymous_tuple.AnonymousTuple(elements))

        self.assertEqual(str(bar.type_signature), mixed_type_string)

    @parameterized.named_parameters([
        ('test_n_' + str(n) + '_m_' + str(m), n, m)
        for n, m in itertools.product([1, 2, 3], [1, 2, 3])
    ])
    def test_n_tuple_federated_zip_mixed_args(self, n, m):
        """Zip n CLIENTS int32 pairs followed by m CLIENTS int32 scalars."""
        pair_type = tff.FederatedType([tf.int32, tf.int32], tff.CLIENTS)
        scalar_type = tff.FederatedType(tf.int32, tff.CLIENTS)
        param_types = [pair_type] * n + [scalar_type] * m
        result_type = tff.FederatedType([[tf.int32, tf.int32]] * n +
                                        [tf.int32] * m, tff.CLIENTS)
        expected = str(
            tff.FunctionType(tff.NamedTupleType(param_types), result_type))

        @tff.federated_computation(param_types)
        def baz(x):
            return tff.federated_zip(x)

        self.assertEqual(str(baz.type_signature), expected)

    def test_federated_apply_raises_warning(self):
        """tff.federated_apply() still traces but emits one DeprecationWarning."""
        with warnings.catch_warnings(record=True) as w:
            # Record every warning, even ones that already fired elsewhere.
            warnings.simplefilter('always')

            @tff.federated_computation(tff.FederatedType(tf.int32, tff.SERVER))
            def foo(x):
                return tff.federated_apply(
                    tff.tf_computation(lambda x: x * x, tf.int32), x)

            self.assertLen(w, 1)
            # `category` is the warning class itself; check it as a class
            # instead of instantiating it.
            self.assertTrue(issubclass(w[0].category, DeprecationWarning))
            self.assertIn('tff.federated_apply() is deprecated',
                          str(w[0].message))
            self.assertEqual(str(foo.type_signature),
                             '(int32@SERVER -> int32@SERVER)')

    def test_federated_value_with_bool_on_clients(self):
        """federated_value places an unplaced bool as all-equal at CLIENTS."""

        @tff.federated_computation(tf.bool)
        def foo(x):
            return tff.federated_value(x, tff.CLIENTS)

        expected = '(bool -> bool@CLIENTS)'
        self.assertEqual(str(foo.type_signature), expected)

    def test_federated_value_raw_np_scalar(self):
        """NumPy scalars can be placed with federated_value and keep dtype."""

        @tff.federated_computation
        def test_np_values():
            tff_float = tff.federated_value(np.float64(0), tff.SERVER)
            self.assertEqual(str(tff_float.type_signature), 'float64@SERVER')
            tff_int = tff.federated_value(np.int64(0), tff.SERVER)
            self.assertEqual(str(tff_int.type_signature), 'int64@SERVER')
            return (tff_float, tff_int)

        float_result, int_result = test_np_values()
        self.assertEqual(float_result, 0.0)
        self.assertEqual(int_result, 0)

    def test_federated_value_raw_tf_scalar_variable(self):
        """Placing a raw tf.Variable with federated_value raises TypeError."""
        variable = tf.Variable(initial_value=0., name='test_var')
        expected_regex = ('TensorFlow construct (.*) has been '
                          'encountered in a federated context.')
        with self.assertRaisesRegex(TypeError, expected_regex):
            tff.federated_value(variable, tff.SERVER)

    def test_federated_value_with_bool_on_server(self):
        """federated_value places an unplaced bool at SERVER."""

        @tff.federated_computation(tf.bool)
        def foo(x):
            return tff.federated_value(x, tff.SERVER)

        expected = '(bool -> bool@SERVER)'
        self.assertEqual(str(foo.type_signature), expected)

    def test_sequence_sum(self):
        """sequence_sum on bare, SERVER-placed and CLIENTS-placed sequences."""
        int_seq = tff.SequenceType(tf.int32)
        cases = [
            (int_seq, '(int32* -> int32)'),
            (tff.FederatedType(int_seq, tff.SERVER),
             '(int32*@SERVER -> int32@SERVER)'),
            (tff.FederatedType(int_seq, tff.CLIENTS),
             '({int32*}@CLIENTS -> {int32}@CLIENTS)'),
        ]
        for param_type, expected in cases:

            @tff.federated_computation(param_type)
            def foo(x):
                return tff.sequence_sum(x)

            self.assertEqual(str(foo.type_signature), expected)

    def test_sequence_map(self):
        """sequence_map on bare, SERVER-placed and CLIENTS-placed sequences."""

        @tff.tf_computation(tf.int32)
        def over_threshold(x):
            return x > 10

        int_seq = tff.SequenceType(tf.int32)
        cases = [
            (int_seq, '(int32* -> bool*)'),
            (tff.FederatedType(int_seq, tff.SERVER),
             '(int32*@SERVER -> bool*@SERVER)'),
            (tff.FederatedType(int_seq, tff.CLIENTS),
             '({int32*}@CLIENTS -> {bool*}@CLIENTS)'),
        ]
        for param_type, expected in cases:

            @tff.federated_computation(param_type)
            def foo(x):
                return tff.sequence_map(over_threshold, x)

            self.assertEqual(str(foo.type_signature), expected)

    def test_sequence_reduce(self):
        """sequence_reduce on bare, SERVER-placed and CLIENTS-placed sequences."""
        add_numbers = tff.tf_computation(tf.add, [tf.int32, tf.int32])
        int_seq = tff.SequenceType(tf.int32)
        cases = [
            (int_seq, '(int32* -> int32)'),
            (tff.FederatedType(int_seq, tff.SERVER),
             '(int32*@SERVER -> int32@SERVER)'),
            (tff.FederatedType(int_seq, tff.CLIENTS),
             '({int32*}@CLIENTS -> {int32}@CLIENTS)'),
        ]
        for param_type, expected in cases:

            @tff.federated_computation(param_type)
            def foo(x):
                return tff.sequence_reduce(x, 0, add_numbers)

            self.assertEqual(str(foo.type_signature), expected)

    @core_test.executors(
        ('local', executor_stacks.create_local_executor()), )
    def test_federated_zip_with_twenty_elements_local_executor(self):
        """Regression test: zipping 20 CLIENTS values finishes when executed."""
        n = 20
        n_clients = 2
        client_int = tff.FederatedType(tf.int32, tff.CLIENTS)

        @tff.federated_computation([client_int] * n)
        def foo(x):
            return tff.federated_zip(x)

        # One independent list of per-client values per zipped argument.
        data = [list(range(n_clients)) for _ in range(n)]

        # This call never returned back when the local executor scaled
        # factorially with the number of elements zipped.
        foo(data)
Code example #16
0
  def test_federated_sum_with_client_string(self):
    """federated_sum rejects CLIENTS-placed strings with TypeError."""
    with self.assertRaises(TypeError):
      client_strings = tff.FederatedType(tf.string, tff.CLIENTS)

      @tff.federated_computation(client_strings)
      def _(x):
        return tff.federated_sum(x)
Code example #17
0
  def test_federated_collect_with_server_int_fails(self):
    """federated_collect rejects a SERVER-placed argument with TypeError."""
    with self.assertRaises(TypeError):
      server_int = tff.FederatedType(tf.int32, tff.SERVER)

      @tff.federated_computation(server_int)
      def _(x):
        return tff.federated_collect(x)
Code example #18
0
  def test_federated_mean_with_client_int32_fails(self):
    """federated_mean rejects CLIENTS int32 values with TypeError."""
    with self.assertRaises(TypeError):
      client_int = tff.FederatedType(tf.int32, tff.CLIENTS)

      @tff.federated_computation(client_int)
      def _(x):
        return tff.federated_mean(x)
Code example #19
0
class IntrinsicsTest(parameterized.TestCase):

  def test_federated_broadcast_with_server_all_equal_int(self):
    """Broadcasting a SERVER int32 yields an all-equal CLIENTS int32."""
    server_int = tff.FederatedType(tf.int32, tff.SERVER)

    @tff.federated_computation(server_int)
    def foo(x):
      return tff.federated_broadcast(x)

    expected = '(int32@SERVER -> int32@CLIENTS)'
    self.assertEqual(str(foo.type_signature), expected)

  def test_federated_broadcast_with_server_non_all_equal_int(self):
    """A SERVER value declared non-all-equal cannot be broadcast."""
    with self.assertRaises(TypeError):
      non_all_equal_server_int = tff.FederatedType(
          tf.int32, tff.SERVER, all_equal=False)

      @tff.federated_computation(non_all_equal_server_int)
      def _(x):
        return tff.federated_broadcast(x)

  def test_federated_broadcast_with_client_int(self):
    """Broadcasting a CLIENTS-placed value raises TypeError."""
    with self.assertRaises(TypeError):
      client_int = tff.FederatedType(tf.int32, tff.CLIENTS, True)

      @tff.federated_computation(client_int)
      def _(x):
        return tff.federated_broadcast(x)

  def test_federated_broadcast_with_non_federated_val(self):
    """Broadcasting an unplaced int32 raises TypeError."""
    with self.assertRaises(TypeError):
      unplaced_int = tf.int32

      @tff.federated_computation(unplaced_int)
      def _(x):
        return tff.federated_broadcast(x)

  def test_federated_map_with_client_all_equal_int(self):
    """Mapping over an all-equal CLIENTS int32 gives non-all-equal bools."""

    @tff.federated_computation(tff.FederatedType(tf.int32, tff.CLIENTS, True))
    def foo(x):
      over_ten = tff.tf_computation(lambda x: x > 10, tf.int32)
      return tff.federated_map(over_ten, x)

    expected = '(int32@CLIENTS -> {bool}@CLIENTS)'
    self.assertEqual(str(foo.type_signature), expected)

  def test_federated_map_with_client_non_all_equal_int(self):
    """Mapping over non-all-equal CLIENTS int32 values gives CLIENTS bools."""

    @tff.federated_computation(tff.FederatedType(tf.int32, tff.CLIENTS))
    def foo(x):
      over_ten = tff.tf_computation(lambda x: x > 10, tf.int32)
      return tff.federated_map(over_ten, x)

    expected = '({int32}@CLIENTS -> {bool}@CLIENTS)'
    self.assertEqual(str(foo.type_signature), expected)

  def test_federated_map_with_non_federated_val(self):
    """federated_map over an unplaced value raises TypeError."""
    with self.assertRaises(TypeError):

      @tff.federated_computation(tf.int32)
      def _(x):
        over_ten = tff.tf_computation(lambda x: x > 10, tf.int32)
        return tff.federated_map(over_ten, x)

  def test_federated_sum_with_client_int(self):
    """Summing CLIENTS int32 values yields an int32 on SERVER."""
    client_int = tff.FederatedType(tf.int32, tff.CLIENTS)

    @tff.federated_computation(client_int)
    def foo(x):
      return tff.federated_sum(x)

    expected = '({int32}@CLIENTS -> int32@SERVER)'
    self.assertEqual(str(foo.type_signature), expected)

  def test_federated_sum_with_client_string(self):
    """federated_sum rejects CLIENTS-placed strings with TypeError."""
    with self.assertRaises(TypeError):
      client_strings = tff.FederatedType(tf.string, tff.CLIENTS)

      @tff.federated_computation(client_strings)
      def _(x):
        return tff.federated_sum(x)

  def test_federated_sum_with_server_int(self):
    """federated_sum rejects a SERVER-placed argument with TypeError."""
    with self.assertRaises(TypeError):
      server_int = tff.FederatedType(tf.int32, tff.SERVER)

      @tff.federated_computation(server_int)
      def _(x):
        return tff.federated_sum(x)

  def test_federated_zip_with_client_non_all_equal_int_and_bool(self):
    """Mixed all-equal/non-all-equal CLIENTS inputs zip into a CLIENTS tuple."""
    param_types = [
        tff.FederatedType(tf.int32, tff.CLIENTS),
        tff.FederatedType(tf.bool, tff.CLIENTS, True),
    ]

    @tff.federated_computation(param_types)
    def foo(x, y):
      zipped = tff.federated_zip([x, y])
      return zipped

    expected = '(<{int32}@CLIENTS,bool@CLIENTS> -> {<int32,bool>}@CLIENTS)'
    self.assertEqual(str(foo.type_signature), expected)

  def test_federated_zip_with_single_unnamed_int_client(self):
    """Zipping a 1-tuple of CLIENTS int32 values gives a CLIENTS 1-tuple."""
    param_types = [tff.FederatedType(tf.int32, tff.CLIENTS)]

    @tff.federated_computation(param_types)
    def foo(x):
      return tff.federated_zip(x)

    expected = '(<{int32}@CLIENTS> -> {<int32>}@CLIENTS)'
    self.assertEqual(str(foo.type_signature), expected)

  def test_federated_zip_with_single_unnamed_int_server(self):
    """Zipping a 1-tuple of SERVER int32 values gives a SERVER 1-tuple."""
    param_types = [tff.FederatedType(tf.int32, tff.SERVER)]

    @tff.federated_computation(param_types)
    def foo(x):
      return tff.federated_zip(x)

    expected = '(<int32@SERVER> -> <int32>@SERVER)'
    self.assertEqual(str(foo.type_signature), expected)

  def test_federated_zip_with_single_named_bool_clients(self):
    """A named 1-tuple of CLIENTS bools keeps its name after zipping."""
    param_types = [('a', tff.FederatedType(tf.bool, tff.CLIENTS))]

    @tff.federated_computation(param_types)
    def foo(x):
      return tff.federated_zip(x)

    expected = '(<a={bool}@CLIENTS> -> {<a=bool>}@CLIENTS)'
    self.assertEqual(str(foo.type_signature), expected)

  def test_federated_zip_with_single_named_bool_server(self):
    """A named 1-tuple of SERVER bools keeps its name after zipping."""
    param_types = [('a', tff.FederatedType(tf.bool, tff.SERVER))]

    @tff.federated_computation(param_types)
    def foo(x):
      return tff.federated_zip(x)

    expected = '(<a=bool@SERVER> -> <a=bool>@SERVER)'
    self.assertEqual(str(foo.type_signature), expected)

  def test_federated_zip_with_names_client_non_all_equal_int_and_bool(self):
    """Zipping a dict names each zipped CLIENTS field after its key."""
    param_types = [
        tff.FederatedType(tf.int32, tff.CLIENTS),
        tff.FederatedType(tf.bool, tff.CLIENTS, True),
    ]

    @tff.federated_computation(param_types)
    def foo(x, y):
      return tff.federated_zip({'x': x, 'y': y})

    expected = ('(<{int32}@CLIENTS,bool@CLIENTS> -> '
                '{<x=int32,y=bool>}@CLIENTS)')
    self.assertEqual(str(foo.type_signature), expected)

  def test_federated_zip_with_client_all_equal_int_and_bool(self):
    """All-equal CLIENTS inputs zip into a non-all-equal CLIENTS tuple."""
    param_types = [
        tff.FederatedType(tf.int32, tff.CLIENTS, True),
        tff.FederatedType(tf.bool, tff.CLIENTS, True),
    ]

    @tff.federated_computation(param_types)
    def foo(x, y):
      zipped = tff.federated_zip([x, y])
      return zipped

    expected = '(<int32@CLIENTS,bool@CLIENTS> -> {<int32,bool>}@CLIENTS)'
    self.assertEqual(str(foo.type_signature), expected)

  def test_federated_zip_with_names_client_all_equal_int_and_bool(self):
    """Named zip of all-equal CLIENTS values taken by position from arg."""
    param_types = [
        tff.FederatedType(tf.int32, tff.CLIENTS, True),
        tff.FederatedType(tf.bool, tff.CLIENTS, True),
    ]

    @tff.federated_computation(param_types)
    def foo(arg):
      first, second = arg[0], arg[1]
      return tff.federated_zip({'x': first, 'y': second})

    expected = ('(<int32@CLIENTS,bool@CLIENTS> -> '
                '{<x=int32,y=bool>}@CLIENTS)')
    self.assertEqual(str(foo.type_signature), expected)

  def test_federated_zip_with_server_int_and_bool(self):
    """Two SERVER values zip into a single SERVER tuple."""
    param_types = [
        tff.FederatedType(tf.int32, tff.SERVER),
        tff.FederatedType(tf.bool, tff.SERVER),
    ]

    @tff.federated_computation(param_types)
    def foo(x, y):
      zipped = tff.federated_zip([x, y])
      return zipped

    expected = '(<int32@SERVER,bool@SERVER> -> <int32,bool>@SERVER)'
    self.assertEqual(str(foo.type_signature), expected)

  def test_federated_zip_with_names_server_int_and_bool(self):
    """A named tuple of SERVER values zips with its names preserved."""
    param_types = [
        ('a', tff.FederatedType(tf.int32, tff.SERVER)),
        ('b', tff.FederatedType(tf.bool, tff.SERVER)),
    ]

    @tff.federated_computation(param_types)
    def foo(arg):
      return tff.federated_zip(arg)

    expected = ('(<a=int32@SERVER,b=bool@SERVER> -> '
                '<a=int32,b=bool>@SERVER)')
    self.assertEqual(str(foo.type_signature), expected)

  def test_federated_collect_with_client_int(self):
    """Collecting CLIENTS int32 values yields an int32 sequence on SERVER."""
    client_int = tff.FederatedType(tf.int32, tff.CLIENTS)

    @tff.federated_computation(client_int)
    def foo(x):
      return tff.federated_collect(x)

    expected = '({int32}@CLIENTS -> int32*@SERVER)'
    self.assertEqual(str(foo.type_signature), expected)

  def test_federated_collect_with_server_int_fails(self):
    """federated_collect rejects a SERVER-placed argument with TypeError."""
    with self.assertRaises(TypeError):
      server_int = tff.FederatedType(tf.int32, tff.SERVER)

      @tff.federated_computation(server_int)
      def _(x):
        return tff.federated_collect(x)

  def test_federated_mean_with_client_float32_without_weight(self):
    """Unweighted mean of CLIENTS float32 values lands on SERVER."""
    client_float = tff.FederatedType(tf.float32, tff.CLIENTS)

    @tff.federated_computation(client_float)
    def foo(x):
      return tff.federated_mean(x)

    expected = '({float32}@CLIENTS -> float32@SERVER)'
    self.assertEqual(str(foo.type_signature), expected)

  def test_federated_mean_with_client_tuple_with_int32_weight(self):
    """Weighted mean of a named float64 pair with int32 CLIENTS weights."""
    value_type = tff.FederatedType([('x', tf.float64), ('y', tf.float64)],
                                   tff.CLIENTS)
    weight_type = tff.FederatedType(tf.int32, tff.CLIENTS)

    @tff.federated_computation([value_type, weight_type])
    def foo(x, y):
      return tff.federated_mean(x, y)

    expected = ('(<{<x=float64,y=float64>}@CLIENTS,{int32}@CLIENTS> '
                '-> <x=float64,y=float64>@SERVER)')
    self.assertEqual(str(foo.type_signature), expected)

  def test_federated_mean_with_client_int32_fails(self):
    """federated_mean rejects CLIENTS int32 values with TypeError."""
    with self.assertRaises(TypeError):
      client_int = tff.FederatedType(tf.int32, tff.CLIENTS)

      @tff.federated_computation(client_int)
      def _(x):
        return tff.federated_mean(x)

  def test_federated_mean_with_string_weight_fails(self):
    """federated_mean rejects a string-typed weight with TypeError."""
    with self.assertRaises(TypeError):
      param_types = [
          tff.FederatedType(tf.float32, tff.CLIENTS),
          tff.FederatedType(tf.string, tff.CLIENTS),
      ]

      @tff.federated_computation(param_types)
      def _(x, y):
        return tff.federated_mean(x, y)

  def test_federated_aggregate_with_client_int(self):
    """Averages CLIENTS int32 values with federated_aggregate.

    The accumulator is a named tuple `(total, count)`. The report stage turns
    the final accumulator into a float32 ratio placed at SERVER.
    """
    # The representation used during the aggregation process will be a named
    # tuple with 2 elements - the integer 'total' that represents the sum of
    # elements encountered, and the integer element 'count'.
    # pylint: disable=invalid-name
    Accumulator = collections.namedtuple('Accumulator', 'total count')
    # pylint: enable=invalid-name
    accumulator_type = tff.NamedTupleType(Accumulator(tf.int32, tf.int32))

    # The operator to use during the first stage simply adds an element to the
    # total and updates the count.
    @tff.tf_computation([accumulator_type, tf.int32])
    def accumulate(accu, elem):
      return Accumulator(accu.total + elem, accu.count + 1)

    # The operator to use during the second stage adds the totals and the
    # counts of two partial accumulators.
    @tff.tf_computation([accumulator_type, accumulator_type])
    def merge(x, y):
      return Accumulator(x.total + y.total, x.count + y.count)

    # The operator to use during the final stage computes the ratio. Uses
    # `tf.cast`, which works in TF 1.x and 2.x; `tf.to_float` was removed in 2.x.
    @tff.tf_computation(accumulator_type)
    def report(accu):
      return tf.cast(accu.total, tf.float32) / tf.cast(accu.count, tf.float32)

    @tff.federated_computation(tff.FederatedType(tf.int32, tff.CLIENTS))
    def foo(x):
      return tff.federated_aggregate(x, Accumulator(0, 0), accumulate, merge,
                                     report)

    self.assertEqual(
        str(foo.type_signature), '({int32}@CLIENTS -> float32@SERVER)')

  def test_federated_reduce_with_tf_add_raw_constant(self):
    """federated_reduce accepts a raw Python int as its initial value."""

    @tff.federated_computation(tff.FederatedType(tf.int32, tff.CLIENTS))
    def foo(x):
      return tff.federated_reduce(
          x, 0, tff.tf_computation(tf.add, [tf.int32, tf.int32]))

    expected = '({int32}@CLIENTS -> int32@SERVER)'
    self.assertEqual(str(foo.type_signature), expected)

  def test_num_over_temperature_threshold_example(self):
    """Counts clients whose temperature exceeds a broadcast SERVER threshold."""

    @tff.federated_computation([
        tff.FederatedType(tf.float32, tff.CLIENTS),
        tff.FederatedType(tf.float32, tff.SERVER)
    ])
    def foo(temperatures, threshold):
      # Map each (temperature, threshold) pair to 1 or 0, then sum on SERVER.
      # Uses `tf.cast` because `tf.to_int32` no longer exists in TF 2.x.
      return tff.federated_sum(
          tff.federated_map(
              tff.tf_computation(
                  lambda x, y: tf.cast(tf.greater(x, y), tf.int32),
                  [tf.float32, tf.float32]),
              [temperatures, tff.federated_broadcast(threshold)]))

    self.assertEqual(
        str(foo.type_signature),
        '(<{float32}@CLIENTS,float32@SERVER> -> int32@SERVER)')

  @parameterized.named_parameters(('test_n_2', 2), ('test_n_3', 3),
                                  ('test_n_5', 5))
  def test_n_tuple_federated_zip_tensor_args(self, n):
    """Zipping n CLIENTS int32 values gives a CLIENTS tuple of n int32s."""
    fed_type = tff.FederatedType(tf.int32, tff.CLIENTS)
    param_types = [fed_type] * n
    result_type = tff.FederatedType([tf.int32] * n, tff.CLIENTS)
    expected = str(
        tff.FunctionType(tff.NamedTupleType(param_types), result_type))

    @tff.federated_computation(param_types)
    def foo(x):
      return tff.federated_zip(x)

    self.assertEqual(str(foo.type_signature), expected)

  @parameterized.named_parameters(
      ('test_n_2_int', 2, tff.FederatedType(tf.int32, tff.CLIENTS)),
      ('test_n_3_int', 3, tff.FederatedType(tf.int32, tff.CLIENTS)),
      ('test_n_5_int', 5, tff.FederatedType(tf.int32, tff.CLIENTS)),
      ('test_n_2_tuple', 2,
       tff.FederatedType([tf.int32, tf.int32], tff.CLIENTS)),
      ('test_n_3_tuple', 3,
       tff.FederatedType([tf.int32, tf.int32], tff.CLIENTS)),
      ('test_n_5_tuple', 5,
       tff.FederatedType([tf.int32, tf.int32], tff.CLIENTS)),
  )
  def test_named_n_tuple_federated_zip(self, n, fed_type):
    """Zipping named structures keeps element names in the result type.

    Two inputs are checked:
    - a dict naming every entry '0'..'n-1';
    - an `AnonymousTuple` in which only even positions carry a name.
    """
    member_type = fed_type.member
    param_type = tff.NamedTupleType([fed_type] * n)

    named_members = []
    mixed_members = []
    for k in range(n):
      label = str(k)
      named_members.append((label, member_type))
      # The mixed variant leaves odd positions unnamed.
      mixed_members.append(
          (label, member_type) if k % 2 == 0 else member_type)

    named_expected = str(
        tff.FunctionType(param_type,
                         tff.FederatedType(named_members, tff.CLIENTS)))
    mixed_expected = str(
        tff.FunctionType(param_type,
                         tff.FederatedType(mixed_members, tff.CLIENTS)))

    @tff.federated_computation([fed_type] * n)
    def foo(x):
      return tff.federated_zip({str(k): x[k] for k in range(n)})

    self.assertEqual(str(foo.type_signature), named_expected)

    @tff.federated_computation([fed_type] * n)
    def bar(x):
      entries = []
      for k in range(n):
        entries.append((str(k) if k % 2 == 0 else None, x[k]))
      return tff.federated_zip(anonymous_tuple.AnonymousTuple(entries))

    self.assertEqual(str(bar.type_signature), mixed_expected)

  @parameterized.named_parameters([
      ('test_n_{}_m_{}'.format(n, m), n, m)
      for n, m in itertools.product([1, 2, 3], [1, 2, 3])
  ])
  def test_n_tuple_federated_zip_mixed_args(self, n, m):
    """Zips `n` pair-valued and `m` scalar CLIENTS values together.

    The argument holds `n` federated `<int32,int32>` values followed by `m`
    federated `int32` values. The zipped result carries the same structure
    under a single CLIENTS placement.
    """
    pair_type = tff.FederatedType([tf.int32, tf.int32], tff.CLIENTS)
    scalar_type = tff.FederatedType(tf.int32, tff.CLIENTS)
    arg_types = [pair_type] * n + [scalar_type] * m
    zipped_type = tff.FederatedType(
        [[tf.int32, tf.int32]] * n + [tf.int32] * m, tff.CLIENTS)
    expected = str(tff.FunctionType(tff.NamedTupleType(arg_types), zipped_type))

    @tff.federated_computation(arg_types)
    def baz(x):
      return tff.federated_zip(x)

    self.assertEqual(str(baz.type_signature), expected)

  def test_federated_apply_with_int(self):
    """`federated_apply` maps a TF predicate over a SERVER-placed int32."""
    expected = '(int32@SERVER -> bool@SERVER)'

    @tff.federated_computation(tff.FederatedType(tf.int32, tff.SERVER))
    def foo(x):
      greater_than_ten = tff.tf_computation(lambda x: x > 10, tf.int32)
      return tff.federated_apply(greater_than_ten, x)

    self.assertEqual(str(foo.type_signature), expected)

  def test_federated_apply_injected_zip_int(self):
    """`federated_apply` accepts a list of two SERVER-placed int32 values.

    Per the test name, a zip of the two values is presumably injected.
    """
    server_int = tff.FederatedType(tf.int32, tff.SERVER)

    @tff.federated_computation([server_int, server_int])
    def foo(x, y):
      # Only the first element is compared. The TF function still declares
      # two int32 parameters so that it accepts the pair.
      compare = tff.tf_computation(lambda x, y: x > 10, [tf.int32, tf.int32])
      return tff.federated_apply(compare, [x, y])

    actual = str(foo.type_signature)
    self.assertEqual(actual, '(<int32@SERVER,int32@SERVER> -> bool@SERVER)')

  def test_federated_value_with_bool_on_clients(self):
    """Placing a bool argument on CLIENTS yields bool@CLIENTS."""
    target = tff.CLIENTS

    @tff.federated_computation(tf.bool)
    def foo(x):
      return tff.federated_value(x, target)

    actual = str(foo.type_signature)
    self.assertEqual(actual, '(bool -> bool@CLIENTS)')

  def test_federated_value_raw_np_scalar(self):
    """Raw NumPy scalars can be placed on SERVER and then returned."""

    @tff.federated_computation
    def place_np_scalars():
      # The placed types are checked while tracing; the values are checked
      # after invocation below.
      placed_float = tff.federated_value(np.float64(0), tff.SERVER)
      self.assertEqual(str(placed_float.type_signature), 'float64@SERVER')
      placed_int = tff.federated_value(np.int64(0), tff.SERVER)
      self.assertEqual(str(placed_int.type_signature), 'int64@SERVER')
      return (placed_float, placed_int)

    float_result, int_result = place_np_scalars()
    self.assertEqual(float_result, 0.0)
    self.assertEqual(int_result, 0)

  def test_federated_value_raw_tf_scalar_variable(self):
    """Passing a raw `tf.Variable` to `federated_value` raises `TypeError`."""
    variable = tf.Variable(initial_value=0., name='test_var')
    error_pattern = ('TensorFlow construct (.*) has been '
                     'encountered in a federated context.')
    with self.assertRaisesRegex(TypeError, error_pattern):
      tff.federated_value(variable, tff.SERVER)

  def test_federated_value_with_bool_on_server(self):
    """Placing a bool argument on SERVER yields bool@SERVER."""
    target = tff.SERVER

    @tff.federated_computation(tf.bool)
    def foo(x):
      return tff.federated_value(x, target)

    actual = str(foo.type_signature)
    self.assertEqual(actual, '(bool -> bool@SERVER)')

  def test_sequence_sum(self):
    """`sequence_sum` types correctly for unplaced, SERVER and CLIENTS input."""
    int_sequence = tff.SequenceType(tf.int32)
    cases = [
        (int_sequence, '(int32* -> int32)'),
        (tff.FederatedType(int_sequence, tff.SERVER),
         '(int32*@SERVER -> int32@SERVER)'),
        (tff.FederatedType(int_sequence, tff.CLIENTS),
         '({int32*}@CLIENTS -> {int32}@CLIENTS)'),
    ]
    for param_type, expected in cases:

      @tff.federated_computation(param_type)
      def foo(x):
        return tff.sequence_sum(x)

      self.assertEqual(str(foo.type_signature), expected)

  def test_sequence_map(self):
    """`sequence_map` types correctly for unplaced, SERVER and CLIENTS input."""

    @tff.tf_computation(tf.int32)
    def over_threshold(x):
      return x > 10

    int_sequence = tff.SequenceType(tf.int32)
    cases = [
        (int_sequence, '(int32* -> bool*)'),
        (tff.FederatedType(int_sequence, tff.SERVER),
         '(int32*@SERVER -> bool*@SERVER)'),
        (tff.FederatedType(int_sequence, tff.CLIENTS),
         '({int32*}@CLIENTS -> {bool*}@CLIENTS)'),
    ]
    for param_type, expected in cases:

      @tff.federated_computation(param_type)
      def foo(x):
        return tff.sequence_map(over_threshold, x)

      self.assertEqual(str(foo.type_signature), expected)

  def test_sequence_reduce(self):
    """`sequence_reduce` types correctly for unplaced, SERVER and CLIENTS input."""
    add_numbers = tff.tf_computation(tf.add, [tf.int32, tf.int32])

    int_sequence = tff.SequenceType(tf.int32)
    cases = [
        (int_sequence, '(int32* -> int32)'),
        (tff.FederatedType(int_sequence, tff.SERVER),
         '(int32*@SERVER -> int32@SERVER)'),
        (tff.FederatedType(int_sequence, tff.CLIENTS),
         '({int32*}@CLIENTS -> {int32}@CLIENTS)'),
    ]
    for param_type, expected in cases:

      @tff.federated_computation(param_type)
      def foo(x):
        # `0` is the raw initial accumulator value.
        return tff.sequence_reduce(x, 0, add_numbers)

      self.assertEqual(str(foo.type_signature), expected)
コード例 #20
0
    def test_federated_broadcast_with_server_non_all_equal_int(self):
        with self.assertRaises(TypeError):

            @tff.federated_computation(tff.FederatedType(tf.int32, tff.SERVER))
            def _(x):
                return tff.federated_broadcast(x)