Exemplo n.º 1
0
  def test_named_n_tuple_federated_zip(self, n, fed_type):
    """Checks `tff.federated_zip` on fully named and partially named tuples.

    Builds two federated computations over `n` arguments of type `fed_type`:
    one zips a dict whose keys are all string indices, the other zips an
    `AnonymousTuple` where only even positions carry a name. The resulting
    type signatures are compared against hand-built expected strings.
    """
    member_type = fed_type.member
    arg_tuple_type = tff.NamedTupleType([fed_type] * n)
    all_named_result = tff.FederatedType(
        [(str(i), member_type) for i in range(n)], tff.CLIENTS)
    half_named_result = tff.FederatedType(
        [(str(i), member_type) if i % 2 == 0 else member_type
         for i in range(n)], tff.CLIENTS)
    expected_named = str(tff.FunctionType(arg_tuple_type, all_named_result))
    expected_mixed = str(tff.FunctionType(arg_tuple_type, half_named_result))

    @tff.federated_computation([fed_type] * n)
    def foo(x):
      return tff.federated_zip({str(i): x[i] for i in range(n)})

    self.assertEqual(str(foo.type_signature), expected_named)

    @tff.federated_computation([fed_type] * n)
    def bar(x):
      elements = []
      for i in range(n):
        # Even positions are labeled with their index; odd ones stay unnamed.
        label = str(i) if i % 2 == 0 else None
        elements.append((label, x[i]))
      return tff.federated_zip(anonymous_tuple.AnonymousTuple(elements))

    self.assertEqual(str(bar.type_signature), expected_mixed)
Exemplo n.º 2
0
    def test_named_n_tuple_federated_zip(self, n, fed_type):
        """Verifies zip signatures for named and half-named federated tuples.

        `foo` zips a dict keyed by stringified indices; `bar` zips an
        `AnonymousTuple` that names only its even-indexed elements. Each
        resulting type signature must match the expected function type.
        """
        indices = range(n)
        param_type = tff.NamedTupleType([fed_type] * n)

        def expected_signature(elements):
            # String form of `param_type -> elements@CLIENTS`.
            zipped = tff.FederatedType(elements, tff.CLIENTS)
            return str(tff.FunctionType(param_type, zipped))

        named_string = expected_signature(
            [(str(k), fed_type.member) for k in indices])
        mixed_string = expected_signature([
            (str(k), fed_type.member) if k % 2 == 0 else fed_type.member
            for k in indices
        ])

        @tff.federated_computation([fed_type] * n)
        def foo(x):
            return tff.federated_zip({str(k): x[k] for k in indices})

        self.assertEqual(str(foo.type_signature), named_string)

        @tff.federated_computation([fed_type] * n)
        def bar(x):
            # Name the element when k is even, otherwise leave it unnamed.
            labeled = [(str(k) if k % 2 == 0 else None, x[k])
                       for k in indices]
            return tff.federated_zip(anonymous_tuple.AnonymousTuple(labeled))

        self.assertEqual(str(bar.type_signature), mixed_string)
Exemplo n.º 3
0
    def test_fed_comp_typical_usage_as_decorator_with_labeled_type(self):
        """Exercises a federated computation declared with a labeled parameter.

        `foo` takes `<f=(int32 -> int32),x=int32>` and applies `f` twice to
        `x`. Passing a two-argument TF computation where a one-argument
        function is expected must be rejected with a `TypeError`.
        """
        @tff.federated_computation((
            ('f', tff.FunctionType(tf.int32, tf.int32)),
            ('x', tf.int32),
        ))
        def foo(f, x):
            return f(f(x))

        @tff.tf_computation(tf.int32)
        def square(x):
            return x**2

        @tff.tf_computation(tf.int32, tf.int32)
        def square_drop_y(x, y):  # pylint: disable=unused-argument
            return x * x

        self.assertEqual(str(foo.type_signature),
                         '(<f=(int32 -> int32),x=int32> -> int32)')

        self.assertEqual(foo(square, 10), int(1e4))
        self.assertEqual(square_drop_y(square_drop_y(10, 5), 100), int(1e4))
        self.assertEqual(square_drop_y(square_drop_y(10, 100), 5), int(1e4))
        # `assertRaisesRegexp` is a deprecated alias that was removed in
        # Python 3.12; `assertRaisesRegex` is the supported spelling.
        with self.assertRaisesRegex(TypeError,
                                    'is not assignable from source type'):
            foo(square_drop_y, 10)
Exemplo n.º 4
0
def check_and_pack_before_broadcast_type_signature(type_spec,
                                                   previously_packed_types):
    """Checks types inferred from `before_broadcast` and packs in `previously_packed_types`.

    After splitting the `next` portion of a `tff.utils.IterativeProcess` into
    `before_broadcast` and `after_broadcast`, `before_broadcast` should have
    type signature `<s1, c1> -> s2`, with `s2` placed at `tff.SERVER`. This
    function validates `s1` and `c1` against the existing entries in
    `previously_packed_types`, then packs `s2` together with the derived
    `prepare` type `(s1.member -> s2.member)`.

    Args:
      type_spec: The `type_signature` attribute of the `before_broadcast` portion
        of the `tff.utils.IterativeProcess` from which we are looking to extract
        an instance of `canonical_form.CanonicalForm`.
      previously_packed_types: Dict containing the information from `next` in the
        iterative process we are parsing. Must contain `s1_type` and
        `c1_type`; `s1_type` is assumed to be a `tff.FederatedType` (its
        `.member` is read).

    Returns:
      A new `dict` holding every entry of `previously_packed_types` plus
      `s2_type` and `prepare_type`.

    Raises:
      TypeError: If `type_spec` is incompatible with
      `previously_packed_types`.
    """
    should_raise = False
    if not (isinstance(type_spec, tff.FunctionType)
            and isinstance(type_spec.parameter, tff.NamedTupleType)
            and len(type_spec.parameter) == 2
            and type_spec.parameter[0] == previously_packed_types['s1_type']
            and type_spec.parameter[1] == previously_packed_types['c1_type']):
        should_raise = True
    # Re-check `FunctionType` here: without it, a non-function `type_spec`
    # would fail with AttributeError on `.result` instead of the TypeError
    # promised by this function's contract.
    if not (isinstance(type_spec, tff.FunctionType)
            and isinstance(type_spec.result, tff.FederatedType)
            and type_spec.result.placement == tff.SERVER):
        should_raise = True
    if should_raise:
        # TODO(b/121290421): These error messages, and indeed the 'track boolean and
        # raise once' logic of these methods as well, is intended to be provisional
        # and revisited when we've seen the compilation pipeline fail more clearly,
        # or maybe preferably iteratively improved as new failure modes are
        # encountered.
        raise TypeError(
            'We have encountered an error checking the type signature '
            'of `before_broadcast`; expected it to have the form '
            '`<s1,c1> -> s2`, with `s1` matching {} and `c1` matching '
            '{}, as defined in `canonical_form.CanonicalForm`, but '
            'encountered a type spec {}'.format(
                previously_packed_types['s1_type'],
                previously_packed_types['c1_type'], type_spec))
    s2 = type_spec.result
    newly_determined_types = {}
    newly_determined_types['s2_type'] = s2
    newly_determined_types['prepare_type'] = tff.FunctionType(
        previously_packed_types['s1_type'].member, s2.member)
    # Later entries win, so newly determined types override stale ones.
    return dict(
        itertools.chain(six.iteritems(previously_packed_types),
                        six.iteritems(newly_determined_types)))
Exemplo n.º 5
0
  def test_n_tuple_federated_zip_tensor_args(self, n):
    """Zipping `n` int32@CLIENTS values yields a single n-tuple at CLIENTS."""
    clients_int = tff.FederatedType(tf.int32, tff.CLIENTS)
    zipped_type = tff.FederatedType([tf.int32] * n, tff.CLIENTS)
    expected = str(
        tff.FunctionType(tff.NamedTupleType([clients_int] * n), zipped_type))

    @tff.federated_computation([tff.FederatedType(tf.int32, tff.CLIENTS)] * n)
    def foo(x):
      return tff.federated_zip(x)

    self.assertEqual(str(foo.type_signature), expected)
Exemplo n.º 6
0
 def _normalize_intrinsic_bit(comp):
   """Rewrites a `FEDERATED_MAP_ALL_EQUAL` intrinsic as a `FEDERATED_MAP`.

   Args:
     comp: A building block exposing `uri` and `type_signature` (presumably a
       `tff_framework.Intrinsic`; confirm against callers).

   Returns:
     A `(computation, modified)` pair. If `comp` is not the
     `FEDERATED_MAP_ALL_EQUAL` intrinsic it is returned unchanged with
     `False`. Otherwise a new `FEDERATED_MAP` intrinsic is returned with
     `True`; its federated argument and result are retyped as
     `tff.FederatedType(..., tff.CLIENTS)` built from the original members.
   """
   is_map_all_equal = comp.uri == tff_framework.FEDERATED_MAP_ALL_EQUAL.uri
   if not is_map_all_equal:
     return comp, False
   old_type = comp.type_signature
   clients_arg_type = tff.FederatedType(old_type.parameter[1].member,
                                        tff.CLIENTS)
   clients_result_type = tff.FederatedType(old_type.result.member, tff.CLIENTS)
   new_type = tff.FunctionType([old_type.parameter[0], clients_arg_type],
                               clients_result_type)
   replacement = tff_framework.Intrinsic(tff_framework.FEDERATED_MAP.uri,
                                         new_type)
   return replacement, True
Exemplo n.º 7
0
  def test_n_tuple_federated_zip_mixed_args(self, n, m):
    """Zips `n` pair-valued and `m` scalar-valued CLIENTS arguments together."""
    pair_at_clients = tff.FederatedType([tf.int32, tf.int32], tff.CLIENTS)
    int_at_clients = tff.FederatedType(tf.int32, tff.CLIENTS)
    param_type = tff.NamedTupleType([pair_at_clients] * n +
                                    [int_at_clients] * m)
    zipped_members = [[tf.int32, tf.int32]] * n + [tf.int32] * m
    zipped_type = tff.FederatedType(zipped_members, tff.CLIENTS)
    expected = str(tff.FunctionType(param_type, zipped_type))

    pair_arg = tff.FederatedType(
        tff.NamedTupleType([tf.int32, tf.int32]), tff.CLIENTS)
    scalar_arg = tff.FederatedType(tf.int32, tff.CLIENTS)

    @tff.federated_computation([pair_arg] * n + [scalar_arg] * m)
    def baz(x):
      return tff.federated_zip(x)

    self.assertEqual(str(baz.type_signature), expected)
Exemplo n.º 8
0
    def test_fed_comp_typical_usage_as_decorator_with_unlabeled_type(self):
        """Traces `f(f(x))` over an unlabeled `<(int32 -> int32),int32>` arg.

        Inside the trace both arguments and the composed result must be
        `tff.Value`s with the expected type strings; the resulting computation
        is then invoked with a cubing TF computation.
        """
        fn_type = tff.FunctionType(tf.int32, tf.int32)

        @tff.federated_computation((fn_type, tf.int32))
        def foo(f, x):
            for arg in (f, x):
                assert isinstance(arg, tff.Value)
            assert str(f.type_signature) == '(int32 -> int32)'
            assert str(x.type_signature) == 'int32'
            composed = f(f(x))
            assert isinstance(composed, tff.Value)
            assert str(composed.type_signature) == 'int32'
            return composed

        self.assertEqual(str(foo.type_signature),
                         '(<(int32 -> int32),int32> -> int32)')

        @tff.tf_computation(tf.int32)
        def third_power(x):
            return x**3

        # (10^3)^3 == 10^9 and (1^3)^3 == 1.
        for base, expected in ((10, int(1e9)), (1, 1)):
            self.assertEqual(foo(third_power, base), expected)
Exemplo n.º 9
0
def check_and_pack_after_aggregate_type_signature(type_spec,
                                                  previously_packed_types):
    """Checks types inferred from `after_aggregate` and packs in `previously_packed_types`.

    After splitting the `next` portion of a `tff.utils.IterativeProcess` all the
    way down, `after_aggregate` should have type signature
    `<<<s1,c1>,c2>,s3> -> <s6,s7,c6>`. The trailing `c6` is only compared when
    the result has exactly three elements. This function validates every
    element of the above, then packs the derived types `s4`, `s5`, `update`
    and `c3`.

    Args:
      type_spec: The `type_signature` attribute of the `after_aggregate` portion
        of the `tff.utils.IterativeProcess` from which we are looking to extract
        an instance of `canonical_form.CanonicalForm`.
      previously_packed_types: Dict containing the information from `next`,
        `before_broadcast` and `before_aggregate` in the iterative process we are
        parsing. Read keys: `s1_type`, `c1_type`, `c2_type`, `s3_type`,
        `s6_type`, `s7_type` and `c6_type` (the last one when the result has
        three elements or an error is reported). The `.member` of `s1`, `s3`,
        `s6`, `s7`, `c1` and `c2` is read, so these are assumed to be
        `tff.FederatedType`s.

    Returns:
      A new `dict` holding every entry of `previously_packed_types` plus
      `s4_type`, `s5_type`, `update_type` and `c3_type`.

    Raises:
      TypeError: If `type_spec` is incompatible with
      `previously_packed_types`.
    """
    should_raise = False
    # Verify the nesting structure before indexing into it, so a malformed
    # `type_spec` yields the TypeError below rather than an AttributeError /
    # IndexError from the element lookups. Lengths are lower bounds only, to
    # accept everything the element checks could previously evaluate.
    if not (isinstance(type_spec, tff.FunctionType)
            and isinstance(type_spec.parameter, tff.NamedTupleType)
            and len(type_spec.parameter) >= 2
            and isinstance(type_spec.parameter[0], tff.NamedTupleType)
            and len(type_spec.parameter[0]) >= 2
            and isinstance(type_spec.parameter[0][0], tff.NamedTupleType)
            and len(type_spec.parameter[0][0]) >= 2
            and isinstance(type_spec.result, tff.NamedTupleType)
            and len(type_spec.result) >= 2):
        should_raise = True
    else:
        parameter = type_spec.parameter
        result = type_spec.result
        if not (parameter[0][0][0] == previously_packed_types['s1_type']
                and parameter[0][0][1] == previously_packed_types['c1_type']
                and parameter[0][1] == previously_packed_types['c2_type']
                and parameter[1] == previously_packed_types['s3_type']):
            should_raise = True
        if not (result[0] == previously_packed_types['s6_type']
                and result[1] == previously_packed_types['s7_type']):
            should_raise = True
        # The client-side output `c6` is optional in the result tuple.
        if (len(result) == 3
                and result[2] != previously_packed_types['c6_type']):
            should_raise = True
    if should_raise:
        # TODO(b/121290421): These error messages, and indeed the 'track boolean and
        # raise once' logic of these methods as well, is intended to be provisional
        # and revisited when we've seen the compilation pipeline fail more clearly,
        # or maybe preferably iteratively improved as new failure modes are
        # encountered.
        raise TypeError(
            'Encountered a type error while checking `after_aggregate`; '
            'expected a type signature of the form '
            '`<<<s1,c1>,c2>,s3> -> <s6,s7,c6>`, where s1 matches {}, '
            'c1 matches {}, c2 matches {}, s3 matches {}, s6 matches '
            '{}, s7 matches {}, c6 matches {}, as defined in '
            '`canonical_form.CanonicalForm`. Encountered a type signature '
            '{}.'.format(previously_packed_types['s1_type'],
                         previously_packed_types['c1_type'],
                         previously_packed_types['c2_type'],
                         previously_packed_types['s3_type'],
                         previously_packed_types['s6_type'],
                         previously_packed_types['s7_type'],
                         previously_packed_types['c6_type'], type_spec))
    s4_type = tff.FederatedType([
        previously_packed_types['s1_type'].member,
        previously_packed_types['s3_type'].member
    ], tff.SERVER)
    s5_type = tff.FederatedType([
        previously_packed_types['s6_type'].member,
        previously_packed_types['s7_type'].member
    ], tff.SERVER)
    newly_determined_types = {}
    newly_determined_types['s4_type'] = s4_type
    newly_determined_types['s5_type'] = s5_type
    newly_determined_types['update_type'] = tff.FunctionType(
        s4_type.member, s5_type.member)
    c3_type = tff.FederatedType([
        previously_packed_types['c1_type'].member,
        previously_packed_types['c2_type'].member
    ], tff.CLIENTS)
    newly_determined_types['c3_type'] = c3_type
    # Later entries win, so newly determined types override stale ones.
    return dict(
        itertools.chain(six.iteritems(previously_packed_types),
                        six.iteritems(newly_determined_types)))
Exemplo n.º 10
0
def check_and_pack_before_aggregate_type_signature(type_spec,
                                                   previously_packed_types):
    """Checks types inferred from `before_aggregate` and packs in `previously_packed_types`.

    After splitting the `after_broadcast` portion of a
    `tff.utils.IterativeProcess` into `before_aggregate` and `after_aggregate`,
    `before_aggregate` should have type signature
    `<<s1,c1>,c2> -> <c5,zero,accumulate,merge,report>`. This
    function validates `s1`, `c1` and `c2` against the existing entries in
    `previously_packed_types` (`c2` must be a CLIENTS-placed value whose member
    matches the member of `s2`), then packs `c2`, `c3`, `c4`, `c5`, `s3`,
    `zero`, `accumulate`, `merge`, `report` and `work`.

    Args:
      type_spec: The `type_signature` attribute of the `before_aggregate` portion
        of the `tff.utils.IterativeProcess` from which we are looking to extract
        an instance of `canonical_form.CanonicalForm`.
      previously_packed_types: Dict containing the information from `next` and
        `before_broadcast` in the iterative process we are parsing. Read keys:
        `s1_type`, `c1_type`, `s2_type` and `c6_type`; the `.member` of `c1`,
        `s2` and `c6` is read, so these are assumed to be `tff.FederatedType`s.

    Returns:
      A new `dict` holding every entry of `previously_packed_types` plus the
      newly determined types listed above (keys suffixed with `_type`).

    Raises:
      TypeError: If `type_spec` is incompatible with
      `previously_packed_types`.
    """
    should_raise = False
    # Verify the outer shape first. The element checks below index into
    # `type_spec.parameter`, so without this gate a malformed signature would
    # surface as AttributeError/IndexError instead of the TypeError below.
    if not (isinstance(type_spec, tff.FunctionType)
            and isinstance(type_spec.parameter, tff.NamedTupleType)
            and len(type_spec.parameter) >= 2):
        should_raise = True
    else:
        parameter = type_spec.parameter
        result = type_spec.result
        if not (isinstance(parameter[0], tff.NamedTupleType)
                and len(parameter[0]) == 2
                and parameter[0][0] == previously_packed_types['s1_type']
                and parameter[0][1] == previously_packed_types['c1_type']):
            should_raise = True
        # `c2` is the broadcast of `s2`: same member, placed at CLIENTS.
        if not (isinstance(parameter[1], tff.FederatedType)
                and parameter[1].placement == tff.CLIENTS
                and parameter[1].member
                == previously_packed_types['s2_type'].member):
            should_raise = True
        # `result[1]` is the accumulator type; `accumulate`, `merge` and
        # `report` must all be consistent with it. `result[4]` is checked to be
        # a function before its `.parameter`/`.result` are read.
        if not (isinstance(result, tff.NamedTupleType)
                and len(result) == 5
                and isinstance(result[0], tff.FederatedType)
                and result[0].placement == tff.CLIENTS
                and tff_framework.is_tensorflow_compatible_type(result[1])
                and result[2] == tff.FunctionType(
                    [result[1], result[0].member], result[1])
                and result[3] == tff.FunctionType(
                    [result[1], result[1]], result[1])
                and isinstance(result[4], tff.FunctionType)
                and result[4].parameter == result[1]
                and tff_framework.is_tensorflow_compatible_type(
                    result[4].result)):
            should_raise = True
    if should_raise:
        # TODO(b/121290421): These error messages, and indeed the 'track boolean and
        # raise once' logic of these methods as well, is intended to be provisional
        # and revisited when we've seen the compilation pipeline fail more clearly,
        # or maybe preferably iteratively improved as new failure modes are
        # encountered.
        raise TypeError(
            'Encountered a type error while checking '
            '`before_aggregate`. Expected a type signature of the '
            'form `<<s1,c1>,c2> -> <c5,zero,accumulate,merge,report>`, '
            'where `s1` matches {}, `c1` matches {}, and `c2` matches '
            'the result of broadcasting {}, as defined in '
            '`canonical_form.CanonicalForm`. Found type signature {}.'.format(
                previously_packed_types['s1_type'],
                previously_packed_types['c1_type'],
                previously_packed_types['s2_type'], type_spec))
    newly_determined_types = {}
    c2_type = type_spec.parameter[1]
    newly_determined_types['c2_type'] = c2_type
    c3_type = tff.FederatedType(
        [previously_packed_types['c1_type'].member, c2_type.member],
        tff.CLIENTS)
    newly_determined_types['c3_type'] = c3_type
    c5_type = type_spec.result[0]
    zero_type = tff.FunctionType(None, type_spec.result[1])
    accumulate_type = type_spec.result[2]
    merge_type = type_spec.result[3]
    report_type = type_spec.result[4]
    newly_determined_types['c5_type'] = c5_type
    newly_determined_types['zero_type'] = zero_type
    newly_determined_types['accumulate_type'] = accumulate_type
    newly_determined_types['merge_type'] = merge_type
    newly_determined_types['report_type'] = report_type
    newly_determined_types['s3_type'] = tff.FederatedType(
        report_type.result, tff.SERVER)
    c4_type = tff.FederatedType([
        newly_determined_types['c5_type'].member,
        previously_packed_types['c6_type'].member
    ], tff.CLIENTS)
    newly_determined_types['c4_type'] = c4_type
    newly_determined_types['work_type'] = tff.FunctionType(
        c3_type.member, c4_type.member)
    # Later entries win, so newly determined types override stale ones.
    return dict(
        itertools.chain(six.iteritems(previously_packed_types),
                        six.iteritems(newly_determined_types)))
Exemplo n.º 11
0
  def __init__(self, initialize, prepare, work, zero, accumulate, merge, report,
               update):
    """Constructs a representation of a MapReduce-like iterative process.

    NOTE: Every computation passed in must be a plain TensorFlow computation,
    i.e., an instance of `tff.Computation` produced by the
    `tff.tf_computation` decorator/wrapper.

    Args:
      initialize: The computation that produces the initial server state.
      prepare: The computation that prepares the input for the clients.
      work: The client-side work computation.
      zero: The computation that produces the initial state for accumulators.
      accumulate: The computation that adds a client update to an accumulator.
      merge: The computation to use for merging pairs of accumulators.
      report: The computation that produces the final server-side aggregate for
        the top level accumulator (the global update).
      update: The computation that takes the global update and the server state
        and produces the new server state, as well as server-side output.

    Raises:
      TypeError: If the Python or TFF types of the arguments are invalid or not
        compatible with each other.
      AssertionError: If the manner in which the given TensorFlow computations
        are represented by TFF does not match what this code is expecting (this
        is an internal error that requires code update).
    """
    named_comps = (
        ('initialize', initialize),
        ('prepare', prepare),
        ('work', work),
        ('zero', zero),
        ('accumulate', accumulate),
        ('merge', merge),
        ('report', report),
        ('update', update),
    )
    # Each computation is fully vetted (Python type, then proto kind) before
    # moving on to the next one.
    for name, candidate in named_comps:
      py_typecheck.check_type(candidate, tff.Computation, name)

      # TODO(b/130633916): Remove private access once an appropriate API for it
      # becomes available.
      proto = candidate._computation_proto  # pylint: disable=protected-access

      if not isinstance(proto, computation_pb2.Computation):
        # Raised explicitly (not via `assert`) so the check survives `-O`.
        raise AssertionError('Cannot find the embedded computation definition.')
      kind = proto.WhichOneof('computation')
      if kind != 'tensorflow':
        raise TypeError('Expected all computations supplied as arguments to '
                        'be plain TensorFlow, found {}.'.format(kind))

    init_sig = initialize.type_signature
    prepare_sig = prepare.type_signature
    work_sig = work.type_signature
    zero_sig = zero.type_signature
    accumulate_sig = accumulate.type_signature
    merge_sig = merge.type_signature
    report_sig = report.type_signature
    update_sig = update.type_signature

    # initialize -> prepare: the server state feeds `prepare`.
    if prepare_sig.parameter != init_sig.result:
      raise TypeError(
          'The `prepare` computation expects an argument of type {}, '
          'which does not match the result type {} of `initialize`.'.format(
              prepare_sig.parameter, init_sig.result))

    work_param = work_sig.parameter
    if not (isinstance(work_param, tff.NamedTupleType) and
            len(work_param) == 2):
      raise TypeError(
          'The `work` computation expects an argument of type {} that is not '
          'a two-tuple.'.format(work_param))

    if work_param[1] != prepare_sig.result:
      raise TypeError(
          'The `work` computation expects an argument tuple with type {} as '
          'the second element (the initial client state from the server), '
          'which does not match the result type {} of `prepare`.'.format(
              work_param[1], prepare_sig.result))

    work_result = work_sig.result
    if not (isinstance(work_result, tff.NamedTupleType) and
            len(work_result) == 2):
      raise TypeError(
          'The `work` computation returns a result  of type {} that is not a '
          'two-tuple.'.format(work_result))

    # `accumulate` folds the first element of `work`'s output into the
    # accumulator type produced by `zero`.
    accumulator_type = zero_sig.result
    expected_accumulate_type = tff.FunctionType(
        [accumulator_type, work_result[0]], accumulator_type)
    if accumulate_sig != expected_accumulate_type:
      raise TypeError(
          'The `accumulate` computation has type signature {}, which does '
          'not match the expected {} as implied by the type signatures of '
          '`zero` and `work`.'.format(accumulate_sig,
                                      expected_accumulate_type))

    merged_type = accumulate_sig.result
    expected_merge_type = tff.FunctionType([merged_type, merged_type],
                                           merged_type)
    if merge_sig != expected_merge_type:
      raise TypeError(
          'The `merge` computation has type signature {}, which does '
          'not match the expected {} as implied by the type signature '
          'of `accumulate`.'.format(merge_sig, expected_merge_type))

    if report_sig.parameter != merge_sig.result:
      raise TypeError(
          'The `report` computation expects an argument of type {}, '
          'which does not match the result type {} of `merge`.'.format(
              report_sig.parameter, merge_sig.result))

    expected_update_parameter_type = tff.to_type(
        [init_sig.result, report_sig.result])
    if update_sig.parameter != expected_update_parameter_type:
      raise TypeError(
          'The `update` computation expects an argument of type {}, '
          'which does not match the expected {} as implied by the type '
          'signatures of `initialize` and `report`.'.format(
              update_sig.parameter, expected_update_parameter_type))

    update_result = update_sig.result
    if not (isinstance(update_result, tff.NamedTupleType) and
            len(update_result) == 2):
      raise TypeError(
          'The `update` computation returns a result  of type {} that is not '
          'a two-tuple.'.format(update_result))

    # The first element of `update`'s output is the new server state, which
    # must have the same type as the state `initialize` produces.
    if update_result[0] != init_sig.result:
      raise TypeError(
          'The `update` computation returns a result tuple with type {} as '
          'the first element (the updated state of the server), which does '
          'not match the result type {} of `initialize`.'.format(
              update_result[0], init_sig.result))

    self._initialize = initialize
    self._prepare = prepare
    self._work = work
    self._zero = zero
    self._accumulate = accumulate
    self._merge = merge
    self._report = report
    self._update = update