Example #1
0
  def test_two_param_next_raises(self):
    """A `next_fn` declaring only (state, weights) parameters is rejected.

    The update argument is omitted, so construction is expected to fail with
    `TemplateNextFnNumArgsError`.
    """
    param_types = (SERVER_INT, MODEL_WEIGHTS_TYPE)

    @federated_computation.federated_computation(*param_types)
    def two_arg_next_fn(state, weights):
      measurements = server_zero()
      return MeasuredProcessOutput(state, weights, measurements)

    with self.assertRaises(errors.TemplateNextFnNumArgsError):
      finalizers.FinalizerProcess(test_initialize_fn, two_arg_next_fn)
Example #2
0
  def test_next_return_tuple_raises(self):
    """A `next_fn` returning a plain tuple is rejected.

    The tuple carries the same three components as a `MeasuredProcessOutput`
    but is not one, so `TemplateNotMeasuredProcessOutputError` is expected.
    """
    arg_types = (SERVER_INT, MODEL_WEIGHTS_TYPE, SERVER_FLOAT)

    @federated_computation.federated_computation(*arg_types)
    def next_fn_returning_tuple(state, weights, update):
      finalized = test_finalizer_result(weights, update)
      return state, finalized, server_zero()

    with self.assertRaises(errors.TemplateNotMeasuredProcessOutputError):
      finalizers.FinalizerProcess(test_initialize_fn, next_fn_returning_tuple)
Example #3
0
  def test_non_server_placed_next_measurements_raises(self):
    """Measurements placed at CLIENTS instead of SERVER are rejected."""

    @federated_computation.federated_computation(SERVER_INT, MODEL_WEIGHTS_TYPE,
                                                 SERVER_FLOAT)
    def clients_measurements_next_fn(state, weights, update):
      result = test_finalizer_result(weights, update)
      # Deliberately the wrong placement for measurements.
      clients_measurement = intrinsics.federated_value(1.0, placements.CLIENTS)
      return MeasuredProcessOutput(state, result, clients_measurement)

    with self.assertRaises(errors.TemplatePlacementError):
      finalizers.FinalizerProcess(test_initialize_fn,
                                  clients_measurements_next_fn)
Example #4
0
  def test_next_return_odict_raises(self):
    """An `OrderedDict` with MeasuredProcessOutput-like keys is rejected."""

    @federated_computation.federated_computation(SERVER_INT, MODEL_WEIGHTS_TYPE,
                                                 SERVER_FLOAT)
    def odict_next_fn(state, weights, update):
      # Same keys as the real output type, but the wrong container.
      output = collections.OrderedDict()
      output['state'] = state
      output['result'] = test_finalizer_result(weights, update)
      output['measurements'] = server_zero()
      return output

    expected_error = errors.TemplateNotMeasuredProcessOutputError
    with self.assertRaises(expected_error):
      finalizers.FinalizerProcess(test_initialize_fn, odict_next_fn)
Example #5
0
  def test_init_tuple_of_federated_types_raises(self):
    """An init_fn returning a tuple of server-placed values is rejected.

    The state type is a structure of federated values rather than a single
    federated value, so `TemplateNotFederatedError` is expected.
    """

    def two_server_zeros():
      return server_zero(), server_zero()

    initialize_fn = federated_computation.federated_computation()(
        two_server_zeros)
    state_type = initialize_fn.type_signature.result

    @federated_computation.federated_computation(state_type,
                                                 MODEL_WEIGHTS_TYPE,
                                                 SERVER_FLOAT)
    def next_fn(state, weights, update):
      finalized = test_finalizer_result(weights, update)
      return MeasuredProcessOutput(state, finalized, server_zero())

    with self.assertRaises(errors.TemplateNotFederatedError):
      finalizers.FinalizerProcess(initialize_fn, next_fn)
Example #6
0
  def test_non_federated_init_next_raises(self):
    """Non-federated TF computations for init and next are rejected.

    Both `initialize_fn` and `next_fn` are plain `tf_computation`s with no
    placement, so `FinalizerProcess` is expected to raise
    `TemplateNotFederatedError`.
    """
    initialize_fn = tensorflow_computation.tf_computation(lambda: 0)
    weights_type = computation_types.to_type(
        model_utils.ModelWeights(tf.float32, ()))

    @tensorflow_computation.tf_computation(tf.int32, weights_type, tf.float32)
    def next_fn(state, weights, update):
      new_weights = model_utils.ModelWeights(weights.trainable + update, ())
      return MeasuredProcessOutput(state, new_weights, 0)

    with self.assertRaises(errors.TemplateNotFederatedError):
      finalizers.FinalizerProcess(initialize_fn, next_fn)
Example #7
0
  def test_next_return_namedtuple_raises(self):
    """A look-alike namedtuple is not accepted as a `MeasuredProcessOutput`.

    The namedtuple has the same type name and fields as the real output type,
    but `TemplateNotMeasuredProcessOutputError` is still expected.
    """
    fake_output_cls = collections.namedtuple(
        'MeasuredProcessOutput', ['state', 'result', 'measurements'])

    @federated_computation.federated_computation(SERVER_INT, MODEL_WEIGHTS_TYPE,
                                                 SERVER_FLOAT)
    def namedtuple_next_fn(state, weights, update):
      return fake_output_cls(
          state=state,
          result=test_finalizer_result(weights, update),
          measurements=server_zero())

    with self.assertRaises(errors.TemplateNotMeasuredProcessOutputError):
      finalizers.FinalizerProcess(test_initialize_fn, namedtuple_next_fn)
Example #8
0
  def test_next_state_not_assignable(self):
    """A float output state does not match the declared `SERVER_INT` state.

    `TemplateStateNotAssignableError` is expected.
    """

    @federated_computation.federated_computation(SERVER_INT, MODEL_WEIGHTS_TYPE,
                                                 SERVER_FLOAT)
    def float_state_next_fn(state, weights, update):
      del state  # Incoming state is ignored; a float state is emitted instead.
      float_state = intrinsics.federated_value(0.0, placements.SERVER)
      result = test_finalizer_result(weights, update)
      measurements = intrinsics.federated_value(1, placements.SERVER)
      return MeasuredProcessOutput(float_state, result, measurements)

    with self.assertRaises(errors.TemplateStateNotAssignableError):
      finalizers.FinalizerProcess(test_initialize_fn, float_state_next_fn)
Example #9
0
  def test_non_server_placed_next_weight_param_raises(self):
    """A `weights` parameter placed at CLIENTS is rejected."""
    clients_weights_type = computation_types.at_clients(
        MODEL_WEIGHTS_TYPE.member)

    @federated_computation.federated_computation(SERVER_INT,
                                                 clients_weights_type,
                                                 SERVER_FLOAT)
    def clients_weights_next_fn(state, weights, update):
      # Sum to SERVER so the rest of the computation type-checks; only the
      # parameter placement is wrong.
      server_weights = intrinsics.federated_sum(weights)
      finalized = test_finalizer_result(server_weights, update)
      return MeasuredProcessOutput(state, finalized, server_zero())

    with self.assertRaises(errors.TemplatePlacementError):
      finalizers.FinalizerProcess(test_initialize_fn, clients_weights_next_fn)
Example #10
0
def test_finalizer():
    """Builds a `FinalizerProcess` that adds a server-placed update to weights.

    `next_fn` passes the state through unchanged. The result is
    `ModelWeights(weights.trainable + updates, ())`, and the measurements
    come from `empty_at_server()`.

    Returns:
      A `finalizers.FinalizerProcess` initialized by `empty_init_fn`.
    """
    state_type = empty_init_fn.type_signature.result
    weights_type = computation_types.at_server(MODEL_WEIGHTS_TYPE)
    update_type = computation_types.at_server(FLOAT_TYPE)

    @federated_computation.federated_computation(state_type, weights_type,
                                                 update_type)
    def next_fn(state, weights, updates):
        add_fn = tensorflow_computation.tf_computation(lambda x, y: x + y)
        summed_trainable = intrinsics.federated_map(
            add_fn, (weights.trainable, updates))
        zipped_weights = intrinsics.federated_zip(
            model_utils.ModelWeights(summed_trainable, ()))
        measurements = empty_at_server()
        return measured_process.MeasuredProcessOutput(
            state, zipped_weights, measurements)

    return finalizers.FinalizerProcess(empty_init_fn, next_fn)
Example #11
0
  def test_construction_with_empty_state_does_not_raise(self):
    """A process whose state is an empty server-placed struct constructs."""
    initialize_fn = federated_computation.federated_computation()(
        lambda: intrinsics.federated_value((), placements.SERVER))

    @federated_computation.federated_computation(
        initialize_fn.type_signature.result, MODEL_WEIGHTS_TYPE, SERVER_FLOAT)
    def next_fn(state, weights, update):
      return MeasuredProcessOutput(
          state, test_finalizer_result(weights, update),
          intrinsics.federated_value(1, placements.SERVER))

    try:
      finalizers.FinalizerProcess(initialize_fn, next_fn)
    except Exception as e:  # pylint: disable=broad-except
      # Catch `Exception` rather than everything, so KeyboardInterrupt and
      # SystemExit still propagate. Include the cause so failures are
      # diagnosable.
      self.fail(
          f'Could not construct a FinalizerProcess with empty state: {e!r}')
Example #12
0
  def test_result_not_assignable_to_weight_raises(self):
    """A result cast to float64 is not assignable to the weights type.

    Every leaf of the finalizer result is cast to `tf.float64`, so
    `FinalizerResultTypeError` is expected.
    """

    def cast_leaf_to_float64(tensor):
      return tf.cast(tensor, tf.float64)

    bad_cast_fn = tensorflow_computation.tf_computation(
        lambda x: tf.nest.map_structure(cast_leaf_to_float64, x))

    @federated_computation.federated_computation(SERVER_INT, MODEL_WEIGHTS_TYPE,
                                                 SERVER_FLOAT)
    def next_fn(state, weights, update):
      uncast_result = test_finalizer_result(weights, update)
      cast_result = intrinsics.federated_map(bad_cast_fn, uncast_result)
      return MeasuredProcessOutput(state, cast_result, server_zero())

    with self.assertRaises(finalizers.FinalizerResultTypeError):
      finalizers.FinalizerProcess(test_initialize_fn, next_fn)
Example #13
0
  def test_constructs_with_non_model_weights_parameter(self):
    """A `weights` parameter that is an OrderedDict (not ModelWeights) is OK.

    The parameter has `trainable`/`non_trainable` keys and is placed at
    SERVER. Construction is expected to succeed.
    """
    non_model_weights_type = computation_types.at_server(
        computation_types.to_type(
            collections.OrderedDict(trainable=tf.float32, non_trainable=())))

    @federated_computation.federated_computation(SERVER_INT,
                                                 non_model_weights_type,
                                                 SERVER_FLOAT)
    def next_fn(state, weights, update):
      del update
      return MeasuredProcessOutput(state, weights, server_zero())

    try:
      finalizers.FinalizerProcess(test_initialize_fn, next_fn)
    except Exception as e:  # pylint: disable=broad-except
      # Catch `Exception` rather than everything, so KeyboardInterrupt and
      # SystemExit still propagate. Include the cause so failures are
      # diagnosable.
      self.fail(f'Could not construct a valid FinalizerProcess: {e!r}')
Example #14
0
def _build_kmeans_finalizer(centroids_type: computation_types.Type,
                            num_centroids: int):
    """Builds a `tff.learning.templates.FinalizerProcess` for k-means.

    The process state is a vector of per-centroid weights. It is initialized
    to ones, with length `num_centroids` and dtype `_WEIGHT_DTYPE`. Each
    `next` call feeds four values to `_update_centroids`: the current
    centroids, the current weights, and the summed
    `(centroid_sums, weights)` update. The new weights become the state and
    the new centroids become the result.

    Args:
      centroids_type: Non-federated TFF type of the centroids; it is placed
        at SERVER for the `next_fn` parameter.
      num_centroids: Number of centroids; sets the length of the weight
        vector.

    Returns:
      A `finalizers.FinalizerProcess` whose measurements are always an empty
      server-placed struct.
    """

    @tensorflow_computation.tf_computation
    def initial_weights():
        return tf.ones((num_centroids,), dtype=_WEIGHT_DTYPE)

    weights_type = initial_weights.type_signature.result

    @federated_computation.federated_computation
    def init_fn():
        return intrinsics.federated_eval(initial_weights, placements.SERVER)

    @tensorflow_computation.tf_computation(centroids_type, weights_type,
                                           centroids_type, weights_type)
    def server_update_tf(centroids, weights, centroid_sums, incoming_weights):
        return _update_centroids(centroids, weights, centroid_sums,
                                 incoming_weights)

    summed_updates_type = computation_types.at_server(
        computation_types.to_type((centroids_type, weights_type)))

    state_type = init_fn.type_signature.result
    server_centroids_type = computation_types.at_server(centroids_type)

    @federated_computation.federated_computation(state_type,
                                                 server_centroids_type,
                                                 summed_updates_type)
    def next_fn(state, current_centroids, summed_updates):
        centroid_sums, incoming_weights = summed_updates
        new_centroids, new_weights = intrinsics.federated_map(
            server_update_tf,
            (current_centroids, state, centroid_sums, incoming_weights))
        return measured_process.MeasuredProcessOutput(
            state=new_weights,
            result=new_centroids,
            measurements=intrinsics.federated_value((), placements.SERVER))

    return finalizers.FinalizerProcess(init_fn, next_fn)
Example #15
0
 def test_init_param_not_empty_raises(self):
   """An initialize_fn that declares a parameter is rejected."""
   server_int_decorator = federated_computation.federated_computation(
       SERVER_INT)
   one_arg_initialize_fn = server_int_decorator(lambda x: x)
   with self.assertRaises(errors.TemplateInitFnParamNotEmptyError):
     finalizers.FinalizerProcess(one_arg_initialize_fn, test_next_fn)
Example #16
0
 def test_next_not_tff_computation_raises(self):
   """A plain Python callable passed as next_fn raises a TypeError."""

   def python_next_fn(state, w, u):
     return MeasuredProcessOutput(state, w + u, ())

   with self.assertRaisesRegex(TypeError, r'Expected .*\.Computation, .*'):
     finalizers.FinalizerProcess(
         initialize_fn=test_initialize_fn, next_fn=python_next_fn)
Example #17
0
 def test_init_not_tff_computation_raises(self):
   """A plain Python callable passed as initialize_fn raises a TypeError."""

   def python_initialize_fn():
     return 0

   expected_pattern = r'Expected .*\.Computation, .*'
   with self.assertRaisesRegex(TypeError, expected_pattern):
     finalizers.FinalizerProcess(
         initialize_fn=python_initialize_fn, next_fn=test_next_fn)
Example #18
0
 def test_construction_does_not_raise(self):
   """The shared `test_initialize_fn`/`test_next_fn` pair constructs."""
   try:
     finalizers.FinalizerProcess(test_initialize_fn, test_next_fn)
   except Exception as e:  # pylint: disable=broad-except
     # Catch `Exception` rather than everything, so KeyboardInterrupt and
     # SystemExit still propagate. Include the cause so failures are
     # diagnosable.
     self.fail(f'Could not construct a valid FinalizerProcess: {e!r}')
Example #19
0
 def test_init_state_not_assignable(self):
   """A float initial state that `test_next_fn` cannot accept is rejected.

   `test_next_fn` is defined elsewhere. Presumably its state parameter is
   `SERVER_INT` (TODO confirm), so a server-placed float state is expected
   to raise `TemplateStateNotAssignableError`.
   """

   def float_server_state():
     return intrinsics.federated_value(0.0, placements.SERVER)

   float_initialize_fn = federated_computation.federated_computation()(
       float_server_state)
   with self.assertRaises(errors.TemplateStateNotAssignableError):
     finalizers.FinalizerProcess(float_initialize_fn, test_next_fn)