Beispiel #1
0
  def test_build_encoded_mean_raise_warning(self):
    """Checks `build_encoded_mean` emits exactly one DeprecationWarning."""
    value = tf.constant(np.random.rand(20))
    value_spec = tf.TensorSpec(value.shape, tf.dtypes.as_dtype(value.dtype))
    encoder = te.encoders.as_gather_encoder(te.encoders.identity(), value_spec)

    with warnings.catch_warnings(record=True) as w:
      # Record every warning, even ones already shown once in this process.
      warnings.simplefilter('always')
      encoding_utils.build_encoded_mean(value, encoder)
    # 'always' records warnings of every category. Counting all of them would
    # make the test fail on any unrelated warning raised during the call, and
    # would not verify that the recorded warning is the deprecation notice.
    deprecations = [
        record for record in w
        if issubclass(record.category, DeprecationWarning)
    ]
    self.assertLen(deprecations, 1)
  def test_build_encoded_mean_raise_warning(self):
    """Checks `build_encoded_mean` raises its DeprecationWarning as an error."""
    value = tf.constant(np.random.rand(20))
    value_spec = tf.TensorSpec(value.shape, tf.dtypes.as_dtype(value.dtype))
    encoder = te.encoders.as_gather_encoder(te.encoders.identity(), value_spec)

    with warnings.catch_warnings(record=True):
      # Promote DeprecationWarning to an exception so assertRaisesRegex sees it.
      warnings.simplefilter('error', DeprecationWarning)
      # The pattern is escaped: unescaped, '.' matches any character and '()'
      # is an empty regex group, so the literal call syntax was never checked.
      with self.assertRaisesRegex(DeprecationWarning,
                                  r'tff\.utils\.build_encoded_mean\(\)'):
        encoding_utils.build_encoded_mean(value, encoder)
  def test_run_encoded_mean(self):
    """Runs the identity-encoded mean and checks several weighted averages."""
    value = np.array([0.0, 1.0, 2.0, -1.0])
    spec = tf.TensorSpec(value.shape, tf.dtypes.as_dtype(value.dtype))
    client_value_type = computation_types.to_type(spec)
    gather_fn = encoding_utils.build_encoded_mean(
        value, te.encoders.as_gather_encoder(te.encoders.identity(), spec))
    initial_state = gather_fn.initialize()

    server_state_type = computation_types.FederatedType(
        gather_fn._initialize_fn.type_signature.result, placements.SERVER)
    client_values_type = computation_types.FederatedType(
        client_value_type, placements.CLIENTS)
    client_weights_type = computation_types.FederatedType(
        computation_types.to_type(tf.float32), placements.CLIENTS)

    @computations.federated_computation(server_state_type, client_values_type,
                                        client_weights_type)
    def call_gather(state, value, weight):
      return gather_fn(state, value, weight)

    # Each case: (per-client values, per-client weights, expected multiple of
    # `value` for the weighted mean). Every call starts from the initial state.
    cases = (
        ([value, value], [1.0, 1.0], 1),
        ([value, value], [0.3, 0.7], 1),
        ([value, 2 * value], [1.0, 2.0], 5 / 3),
    )
    for client_values, client_weights, scale in cases:
      _, weighted_mean = call_gather(initial_state, client_values,
                                     client_weights)
      self.assertAllClose(scale * value, weighted_mean)
Beispiel #4
0
def build_encoded_mean_from_model(
    model_fn: _ModelConstructor,
    encoder_fn: _EncoderConstructor) -> computation_utils.StatefulAggregateFn:
  """Builds `StatefulAggregateFn` for weights of model returned by `model_fn`.

  This method creates a `GatherEncoder` for every trainable weight of the model
  created by `model_fn`, as returned by `encoder_fn`.

  Deprecated: use `tff.learning.framework.build_encoded_mean_process_from_model`
  instead. Every call emits a `DeprecationWarning`.

  Args:
    model_fn: A Python callable with no arguments that returns a
      `tff.learning.Model`.
    encoder_fn: A Python callable with a single argument, which is expected to
      be a `tf.Tensor` of shape and dtype to be encoded. The function must
      return a `tensor_encoding.core.SimpleEncoder`, which expects a `tf.Tensor`
      with compatible type as the input to its `encode` method.

  Returns:
    A `StatefulAggregateFn` for encoding and averaging the weights of the model
    created by `model_fn`.

  Raises:
    TypeError: If `model_fn` or `encoder_fn` are not callable objects.
  """
  # stacklevel=2 attributes the warning to the caller's line rather than to
  # this function, so users can locate the deprecated call site.
  warnings.warn(
      'Deprecation warning: '
      'tff.learning.framework.build_encoded_mean_from_model() is deprecated, '
      'use tff.learning.framework.build_encoded_mean_process_from_model() '
      'instead.', DeprecationWarning, stacklevel=2)

  py_typecheck.check_callable(model_fn)
  py_typecheck.check_callable(encoder_fn)
  trainable_weights = _weights_from_model_fn(model_fn).trainable
  # Build one encoder per trainable weight, mirroring the weights' structure.
  encoders = tf.nest.map_structure(encoder_fn, trainable_weights)
  return encoding_utils.build_encoded_mean(trainable_weights, encoders)
Beispiel #5
0
    def test_build_encoded_mean(self, value_constructor, encoder_constructor):
        """Checks the type signature of the aggregate from build_encoded_mean."""
        value = value_constructor(np.random.rand(20))
        spec = tf.TensorSpec(value.shape, tf.as_dtype(value.dtype))
        value_type = tff.to_type(spec)
        gather_fn = encoding_utils.build_encoded_mean(
            value,
            te.encoders.as_gather_encoder(encoder_constructor(), spec))
        state_type = gather_fn._initialize_fn.type_signature.result

        # Wrap the raw next function so its federated type signature can be
        # inspected: (server state, client values, client float32 weights).
        server_state = tff.FederatedType(state_type, tff.SERVER)
        client_values = tff.FederatedType(value_type, tff.CLIENTS)
        client_weights = tff.FederatedType(tff.to_type(tf.float32),
                                           tff.CLIENTS)
        signature = tff.federated_computation(gather_fn._next_fn,
                                              server_state, client_values,
                                              client_weights).type_signature
        new_state, mean = signature.result[0], signature.result[1]

        self.assertIsInstance(gather_fn, StatefulAggregateFn)
        self.assertEqual(state_type, new_state.member)
        self.assertEqual(tff.SERVER, new_state.placement)
        self.assertEqual(value_type, mean.member)
        self.assertEqual(tff.SERVER, mean.placement)
 def test_build_encoded_mean_raises_bad_structure(self):
   """Expects ValueError when a list of two tensors meets a single encoder."""
   two_tensors = [tf.constant([0.0, 1.0]) for _ in range(2)]
   single_encoder = te.encoders.as_gather_encoder(
       te.encoders.identity(), tf.TensorSpec((2,)))
   with self.assertRaises(ValueError):
     encoding_utils.build_encoded_mean(two_tensors, single_encoder)
 def test_build_encoded_mean_raises_incompatible_encoder(self):
   """Expects TypeError when the encoder's spec shape (3,) differs from (2,)."""
   two_element_tensor = tf.constant([0.0, 1.0])
   three_element_spec = tf.TensorSpec((3,))
   mismatched_encoder = te.encoders.as_gather_encoder(te.encoders.identity(),
                                                      three_element_spec)
   with self.assertRaises(TypeError):
     encoding_utils.build_encoded_mean(two_element_tensor, mismatched_encoder)
 def test_build_encoded_mean_raises_bad_encoder(self, bad_encoder):
   """Expects TypeError for each parameterized `bad_encoder` argument."""
   tensor = tf.constant([0.0, 1.0])
   with self.assertRaises(TypeError):
     encoding_utils.build_encoded_mean(tensor, bad_encoder)