예제 #1
0
    def test_run_encoded_mean(self):
        # An identity-encoded weighted mean must reproduce the exact weighted
        # average of the client values for several weightings.
        sample = np.array([0.0, 1.0, 2.0, -1.0])
        sample_spec = tf.TensorSpec(sample.shape, tf.as_dtype(sample.dtype))
        sample_type = tff.to_type(sample_spec)
        identity_encoder = te.encoders.as_gather_encoder(
            te.encoders.identity(), sample_spec)
        mean_fn = encoding_utils.build_encoded_mean(sample, identity_encoder)
        initial_state = mean_fn.initialize()

        server_state_type = tff.FederatedType(
            mean_fn._initialize_fn.type_signature.result, tff.SERVER)
        client_value_type = tff.FederatedType(sample_type, tff.CLIENTS)
        client_weight_type = tff.FederatedType(tff.to_type(tf.float32),
                                               tff.CLIENTS)

        @tff.federated_computation(server_state_type, client_value_type,
                                   client_weight_type)
        def call_gather(state, value, weight):
            return mean_fn(state, value, weight)

        # Each case: (per-client values, per-client weights, expected mean).
        cases = (
            ([sample, sample], [1.0, 1.0], 1 * sample),
            ([sample, sample], [0.3, 0.7], 1 * sample),
            # (1 * x + 2 * (2 * x)) / (1 + 2) == 5 / 3 * x
            ([sample, 2 * sample], [1.0, 2.0], 5 / 3 * sample),
        )
        for per_client_values, per_client_weights, expected in cases:
            _, mean = call_gather(initial_state, per_client_values,
                                  per_client_weights)
            self.assertAllClose(expected, mean)
예제 #2
0
def build_encoded_mean(values, encoders):
    """Creates a `StatefulAggregateFn` that computes an encoded weighted mean.

    Args:
      values: Values to be encoded by the returned `StatefulAggregateFn`.
        Must be convertible to `tff.Value`.
      encoders: A collection of `GatherEncoder` objects used to encode
        `values`. Must have the same structure as `values`.

    Returns:
      A `StatefulAggregateFn` whose `next_fn` multiplies each client value by
      its client weight, aggregates the weighted values with an encoded sum,
      and divides the result by the federated sum of the weights at
      `tff.SERVER`. Splitting of the decoding work (according to its
      commutativity with sum) is delegated to the encoded-sum helpers.

    Raises:
      ValueError: If `values` and `encoders` do not have the same structure.
      TypeError: If `encoders` are not instances of `GatherEncoder`, or if
        `values` are not compatible with the expected input of the `encoders`.
    """
    tf.nest.assert_same_structure(values, encoders)
    tf.nest.map_structure(
        lambda enc, val: _validate_encoder(enc, val, te.core.GatherEncoder),
        encoders, values)

    value_type = tff_framework.type_from_tensors(values)

    init_fn = _build_initial_state_tf_computation(encoders)
    gather_computations = _build_tf_computations_for_gather(
        init_fn.type_signature.result, value_type, encoders)
    sum_fn = _build_encoded_sum_fn(gather_computations)

    # Client weights and their server-side total are scalar float32 values.
    weight_type = tff.to_type(tf.float32)

    # NOTE: the parameter names of the computations below are kept as-is,
    # since TFF may surface them in the computations' type signatures.
    @tff.tf_computation(value_type, weight_type)
    def multiply_fn(value, weight):
        return tf.nest.map_structure(
            lambda tensor: tensor * tf.cast(weight, tensor.dtype), value)

    @tff.tf_computation(value_type, weight_type)
    def divide_fn(value, denominator):
        return tf.nest.map_structure(
            lambda tensor: tensor / tf.cast(denominator, tensor.dtype), value)

    def encoded_mean_fn(state, values, weight):
        # Weight each client's value before encoding, so the encoded sum is
        # the numerator of the weighted mean.
        client_weighted = tff.federated_map(multiply_fn, [values, weight])
        new_state, weighted_total = sum_fn(state, client_weighted)
        # The denominator is the plain (unencoded) sum of client weights.
        weight_total = tff.federated_sum(weight)
        mean = tff.federated_map(divide_fn, [weighted_total, weight_total])
        return new_state, mean

    return StatefulAggregateFn(next_fn=encoded_mean_fn,
                               initialize_fn=init_fn)
예제 #3
0
    def test_build_encoded_sum(self, value_constructor, encoder_constructor):
        raw = value_constructor(np.random.rand(20))
        spec = tf.TensorSpec(raw.shape, tf.as_dtype(raw.dtype))
        member_type = tff.to_type(spec)
        gather_encoder = te.encoders.as_gather_encoder(encoder_constructor(),
                                                       spec)
        sum_fn = encoding_utils.build_encoded_sum(raw, gather_encoder)
        state_type = sum_fn._initialize_fn.type_signature.result

        # Wrap `next_fn` as a federated computation (state at SERVER, values
        # and float32 weights at CLIENTS) to inspect its full type signature.
        signature = tff.federated_computation(
            sum_fn._next_fn,
            tff.FederatedType(state_type, tff.SERVER),
            tff.FederatedType(member_type, tff.CLIENTS),
            tff.FederatedType(tff.to_type(tf.float32), tff.CLIENTS),
        ).type_signature

        self.assertIsInstance(sum_fn, StatefulAggregateFn)
        # Both the updated state and the summed value must live at SERVER.
        expected = ((state_type, tff.SERVER), (member_type, tff.SERVER))
        for index, (member, placement) in enumerate(expected):
            self.assertEqual(member, signature.result[index].member)
            self.assertEqual(placement, signature.result[index].placement)
예제 #4
0
  def test_build_encoded_broadcast(self, value_constructor,
                                   encoder_constructor):
    raw = value_constructor(np.random.rand(20))
    spec = tf.TensorSpec(raw.shape, tf.as_dtype(raw.dtype))
    expected_type = tff.to_type(spec)
    simple_encoder = te.core.SimpleEncoder(encoder_constructor(), spec)
    broadcast_fn = encoding_utils.build_encoded_broadcast(raw, simple_encoder)
    signature = broadcast_fn._next_fn.type_signature

    self.assertIsInstance(broadcast_fn, StatefulBroadcastFn)
    # The second parameter is the value held at SERVER; the second result is
    # the broadcast value delivered to CLIENTS. Both keep the value's type.
    server_value = signature.parameter[1]
    client_value = signature.result[1]
    self.assertEqual(expected_type, server_value.member)
    self.assertEqual(expected_type, client_value.member)
    self.assertEqual(tff.SERVER, server_value.placement)
    self.assertEqual(tff.CLIENTS, client_value.placement)
예제 #5
0
def _extract_intrinsic_as_reference_to_top_level_lambda(comp, uri):
  """Extracts an intrinsic from `comp` as a reference for the given `uri`.

  The single call to the intrinsic identified by `uri` is replaced in `comp`
  by a reference with a freshly generated name. The extracted call is then
  passed, together with that name, to `_insert_comp_in_top_level_lambda`
  (presumably binding it inside the top-level lambda).

  Args:
    comp: The `tff_framework.Lambda` to transform. The names of lambda
      parameters and block variables in `comp` must be unique.
    uri: A URI of an intrinsic.

  Returns:
    A tuple `(new_comp, modified)`. `new_comp` is the transformed computation.
    `modified` is always `True`, because any input that cannot be transformed
    is rejected with a `ValueError` instead of being returned unchanged.

  Raises:
    ValueError: If `comp` does not contain exactly one call to the intrinsic
      for the given `uri` (both zero and several calls are rejected), or if
      that intrinsic is not exclusively bound by `comp`.
  """
  py_typecheck.check_type(comp, tff_framework.Lambda)
  tff_framework.check_has_unique_names(comp)
  py_typecheck.check_type(uri, six.string_types)
  intrinsics = _get_called_intrinsics(comp, uri)
  length = len(intrinsics)
  if length != 1:
    raise ValueError(
        'Expected a computation with exactly one intrinsic with the uri: {}, '
        'found: {}.'.format(uri, length))
  # The extracted intrinsic may only refer to names the top-level lambda
  # binds; otherwise moving it would leave dangling references.
  if not _are_comps_bound_exclusively_by_top_level_lambda(comp, intrinsics):
    raise ValueError(
        'Expected a computation which binds all the references in the '
        'intrinsic with the uri: {}.'.format(uri))
  # Generate a name that does not collide with any name already in `comp`.
  name_generator = tff_framework.unique_name_generator(comp)
  extracted_intrinsic = intrinsics[0]
  ref_name = six.next(name_generator)
  ref_type = tff.to_type(extracted_intrinsic.type_signature)
  ref = tff_framework.Reference(ref_name, ref_type)

  def _should_transform(comp):
    return tff_framework.is_called_intrinsic(comp, uri)

  def _transform(comp):
    # Swap the intrinsic call for the reference; leave everything else as-is.
    if not _should_transform(comp):
      return comp, False
    return ref, True

  comp, _ = tff_framework.transform_postorder(comp, _transform)
  comp = _insert_comp_in_top_level_lambda(
      comp, name=ref.name, comp_to_insert=extracted_intrinsic)
  return comp, True
예제 #6
0
    def test_build_encoded_broadcast(self, value_constructor,
                                     encoder_constructor):
        raw = value_constructor(np.random.rand(20))
        spec = tf.TensorSpec(raw.shape, tf.as_dtype(raw.dtype))
        member_type = tff.to_type(spec)
        simple_encoder = te.encoders.as_simple_encoder(encoder_constructor(),
                                                       spec)
        broadcast_fn = encoding_utils.build_encoded_broadcast(
            raw, simple_encoder)
        state_type = broadcast_fn._initialize_fn.type_signature.result

        # Wrap `next_fn` with both the state and the value placed at SERVER
        # to inspect its full federated type signature.
        signature = tff.federated_computation(
            broadcast_fn._next_fn,
            tff.FederatedType(state_type, tff.SERVER),
            tff.FederatedType(member_type, tff.SERVER),
        ).type_signature

        self.assertIsInstance(broadcast_fn, StatefulBroadcastFn)
        # The updated state stays at SERVER; the value is delivered to CLIENTS.
        expected = ((state_type, tff.SERVER), (member_type, tff.CLIENTS))
        for index, (member, placement) in enumerate(expected):
            self.assertEqual(member, signature.result[index].member)
            self.assertEqual(placement, signature.result[index].placement)
예제 #7
0
    def test_build_encode_decode_tf_computations_for_broadcast(
            self, encoder_constructor):
        # Tests that the encode/decode computations built for broadcast have
        # matching input-output signatures.
        value_spec = tf.TensorSpec((20, ), tf.float32)
        encoder = te.encoders.as_simple_encoder(encoder_constructor(),
                                                value_spec)

        # `_build_initial_state_tf_computation` is used elsewhere as returning
        # a single computation, so read the state type from its result type
        # rather than tuple-unpacking the return value.
        # NOTE(review): confirm against the encoding_utils version in use.
        initial_state_fn = encoding_utils._build_initial_state_tf_computation(
            encoder)
        state_type = initial_state_fn.type_signature.result
        value_type = tff.to_type(value_spec)
        encode_fn, decode_fn = (
            encoding_utils._build_encode_decode_tf_computations_for_broadcast(
                state_type, value_type, encoder))

        # Encode takes the state as its first argument and returns it first.
        self.assertEqual(state_type, encode_fn.type_signature.parameter[0])
        self.assertEqual(state_type, encode_fn.type_signature.result[0])
        # Output of encode should be the input to decode.
        self.assertEqual(encode_fn.type_signature.result[1],
                         decode_fn.type_signature.parameter)
        # Decode should return the same type as input to encode - value_type.
        self.assertEqual(value_type, encode_fn.type_signature.parameter[1])
        self.assertEqual(value_type, decode_fn.type_signature.result)
예제 #8
0
    def test_build_tf_computations_for_sum(self, encoder_constructor):
        # The partial computations built for a gather encoder must chain
        # together: each one's relevant inputs must match the outputs of the
        # computations that feed it.
        value_spec = tf.TensorSpec((20, ), tf.float32)
        encoder = te.encoders.as_gather_encoder(encoder_constructor(),
                                                value_spec)

        init_fn = encoding_utils._build_initial_state_tf_computation(encoder)
        state_type = init_fn.type_signature.result
        value_type = tff.to_type(value_spec)
        nest_encoder = encoding_utils._build_tf_computations_for_gather(
            state_type, value_type, encoder)

        # get_params: consumes the state; result elements 0, 1 and 2 are the
        # encode, decode-before-sum and decode-after-sum params respectively.
        get_params_sig = nest_encoder.get_params_fn.type_signature
        self.assertEqual(state_type, get_params_sig.parameter)
        encode_params_type = get_params_sig.result[0]
        decode_before_sum_params_type = get_params_sig.result[1]
        decode_after_sum_params_type = get_params_sig.result[2]

        encode_sig = nest_encoder.encode_fn.type_signature
        self.assertEqual(value_type, encode_sig.parameter[0])
        self.assertEqual(encode_params_type, encode_sig.parameter[1])
        self.assertEqual(decode_before_sum_params_type,
                         encode_sig.parameter[2])
        state_update_tensors_type = encode_sig.result[2]

        accumulator_type = nest_encoder.zero_fn.type_signature.result
        self.assertEqual(state_update_tensors_type,
                         accumulator_type.state_update_tensors)

        # accumulate folds an encode result into the accumulator.
        accumulate_sig = nest_encoder.accumulate_fn.type_signature
        self.assertEqual(accumulator_type, accumulate_sig.parameter[0])
        self.assertEqual(encode_sig.result, accumulate_sig.parameter[1])
        self.assertEqual(accumulator_type, accumulate_sig.result)

        # merge combines two accumulators into one.
        merge_sig = nest_encoder.merge_fn.type_signature
        self.assertEqual(accumulator_type, merge_sig.parameter[0])
        self.assertEqual(accumulator_type, merge_sig.parameter[1])
        self.assertEqual(accumulator_type, merge_sig.result)

        report_sig = nest_encoder.report_fn.type_signature
        self.assertEqual(accumulator_type, report_sig.parameter)
        self.assertEqual(accumulator_type, report_sig.result)

        # decode_after_sum turns the summed values back into value_type.
        decode_sig = nest_encoder.decode_after_sum_fn.type_signature
        self.assertEqual(accumulator_type.values, decode_sig.parameter[0])
        self.assertEqual(decode_after_sum_params_type,
                         decode_sig.parameter[1])
        self.assertEqual(value_type, decode_sig.result)

        # update_state consumes the state plus the state-update tensors.
        update_sig = nest_encoder.update_state_fn.type_signature
        self.assertEqual(state_type, update_sig.parameter[0])
        self.assertEqual(state_update_tensors_type, update_sig.parameter[1])
        self.assertEqual(state_type, update_sig.result)
예제 #9
0
  def __init__(self, initialize, prepare, work, zero, accumulate, merge, report,
               update):
    """Constructs a representation of a MapReduce-like iterative process.

    NOTE: Every computation passed in must be a TensorFlow computation, i.e.
    an instance of `tff.Computation` built with the `tff.tf_computation`
    decorator/wrapper.

    Args:
      initialize: Computation producing the initial server state.
      prepare: Computation whose parameter is the server state and whose
        result is the input handed to the clients.
      work: Client-side computation. Its parameter must be a two-tuple whose
        second element matches the result of `prepare`; its result must be a
        two-tuple whose first element is what `accumulate` consumes.
      zero: Computation producing the initial accumulator value.
      accumulate: Computation of type `(accumulator, client update) ->
        accumulator`, where the accumulator type is the result of `zero`.
      merge: Computation combining two accumulators into one.
      report: Computation turning the top-level (merged) accumulator into the
        global update.
      update: Computation taking `[server state, global update]` and returning
        a two-tuple of the new server state and server-side output.

    Raises:
      TypeError: If the Python or TFF types of the arguments are invalid or not
        compatible with each other.
      AssertionError: If the manner in which the given TensorFlow computations
        are represented by TFF does not match what this code is expecting (this
        is an internal error that requires code update).
    """
    labeled_comps = (
        ('initialize', initialize),
        ('prepare', prepare),
        ('work', work),
        ('zero', zero),
        ('accumulate', accumulate),
        ('merge', merge),
        ('report', report),
        ('update', update),
    )
    for label, comp in labeled_comps:
      py_typecheck.check_type(comp, tff.Computation, label)

      # TODO(b/130633916): Remove private access once an appropriate API for it
      # becomes available.
      proto = comp._computation_proto  # pylint: disable=protected-access

      if not isinstance(proto, computation_pb2.Computation):
        # Raised explicitly (rather than via `assert`) so the check still runs
        # when Python assertions are disabled.
        raise AssertionError('Cannot find the embedded computation definition.')
      kind = proto.WhichOneof('computation')
      if kind != 'tensorflow':
        raise TypeError('Expected all computations supplied as arguments to '
                        'be plain TensorFlow, found {}.'.format(kind))

    server_state_type = initialize.type_signature.result

    prepare_param = prepare.type_signature.parameter
    if prepare_param != server_state_type:
      raise TypeError(
          'The `prepare` computation expects an argument of type {}, '
          'which does not match the result type {} of `initialize`.'.format(
              prepare_param, server_state_type))

    work_param = work.type_signature.parameter
    if not (isinstance(work_param, tff.NamedTupleType) and
            len(work_param) == 2):
      raise TypeError(
          'The `work` computation expects an argument of type {} that is not '
          'a two-tuple.'.format(work_param))

    client_input_type = prepare.type_signature.result
    if work_param[1] != client_input_type:
      raise TypeError(
          'The `work` computation expects an argument tuple with type {} as '
          'the second element (the initial client state from the server), '
          'which does not match the result type {} of `prepare`.'.format(
              work_param[1], client_input_type))

    work_result = work.type_signature.result
    if not (isinstance(work_result, tff.NamedTupleType) and
            len(work_result) == 2):
      raise TypeError(
          'The `work` computation returns a result  of type {} that is not a '
          'two-tuple.'.format(work_result))

    # `accumulate` must fold the first element of `work`'s result into an
    # accumulator of the type produced by `zero`.
    accumulator_type = zero.type_signature.result
    expected_accumulate_type = tff.FunctionType(
        [accumulator_type, work_result[0]], accumulator_type)
    if accumulate.type_signature != expected_accumulate_type:
      raise TypeError(
          'The `accumulate` computation has type signature {}, which does '
          'not match the expected {} as implied by the type signatures of '
          '`zero` and `work`.'.format(accumulate.type_signature,
                                      expected_accumulate_type))

    # `merge` combines two accumulators of `accumulate`'s result type.
    accumulated_type = accumulate.type_signature.result
    expected_merge_type = tff.FunctionType(
        [accumulated_type, accumulated_type], accumulated_type)
    if merge.type_signature != expected_merge_type:
      raise TypeError(
          'The `merge` computation has type signature {}, which does '
          'not match the expected {} as implied by the type signature '
          'of `accumulate`.'.format(merge.type_signature, expected_merge_type))

    report_param = report.type_signature.parameter
    merged_type = merge.type_signature.result
    if report_param != merged_type:
      raise TypeError(
          'The `report` computation expects an argument of type {}, '
          'which does not match the result type {} of `merge`.'.format(
              report_param, merged_type))

    expected_update_parameter_type = tff.to_type(
        [server_state_type, report.type_signature.result])
    update_param = update.type_signature.parameter
    if update_param != expected_update_parameter_type:
      raise TypeError(
          'The `update` computation expects an argument of type {}, '
          'which does not match the expected {} as implied by the type '
          'signatures of `initialize` and `report`.'.format(
              update_param, expected_update_parameter_type))

    update_result = update.type_signature.result
    if not (isinstance(update_result, tff.NamedTupleType) and
            len(update_result) == 2):
      raise TypeError(
          'The `update` computation returns a result  of type {} that is not '
          'a two-tuple.'.format(update_result))

    if update_result[0] != server_state_type:
      raise TypeError(
          'The `update` computation returns a result tuple with type {} as '
          'the first element (the updated state of the server), which does '
          'not match the result type {} of `initialize`.'.format(
              update_result[0], server_state_type))

    self._initialize = initialize
    self._prepare = prepare
    self._work = work
    self._zero = zero
    self._accumulate = accumulate
    self._merge = merge
    self._report = report
    self._update = update