def test_build_encoded_mean_raise_warning(self):
  """`build_encoded_mean` emits exactly one `DeprecationWarning`.

  NOTE(review): a test with this same name also exists alongside this one
  (a variant that escalates `DeprecationWarning` to an error). If both are
  defined in the same class, the later definition silently replaces this one;
  consider renaming one of them.
  """
  value = tf.constant(np.random.rand(20))
  value_spec = tf.TensorSpec(value.shape, tf.dtypes.as_dtype(value.dtype))
  encoder = te.encoders.as_gather_encoder(te.encoders.identity(), value_spec)
  with warnings.catch_warnings(record=True) as caught:
    # 'always' records the warning even if an identical one was already
    # emitted (and would otherwise be deduplicated) earlier in the process.
    warnings.simplefilter('always')
    encoding_utils.build_encoded_mean(value, encoder)
  # Count only deprecation warnings. Unrelated warnings raised while building
  # the computation then cannot make the test flaky, and a warning of the
  # wrong category can no longer satisfy the assertion.
  deprecations = [
      w for w in caught if issubclass(w.category, DeprecationWarning)
  ]
  self.assertLen(deprecations, 1)
def test_build_encoded_mean_raise_warning(self):
  """Calling `build_encoded_mean` raises once deprecations become errors."""
  data = tf.constant(np.random.rand(20))
  spec = tf.TensorSpec(data.shape, tf.dtypes.as_dtype(data.dtype))
  identity_encoder = te.encoders.identity()
  gather_encoder = te.encoders.as_gather_encoder(identity_encoder, spec)
  # Scope the filter change to this block. Escalating DeprecationWarning to an
  # exception lets assertRaisesRegex observe it.
  with warnings.catch_warnings(record=True):
    warnings.simplefilter('error', DeprecationWarning)
    # NOTE(review): the pattern is a regex. '.' matches any character and '()'
    # is an empty group, so literal parentheses are not actually required.
    self.assertRaisesRegex(DeprecationWarning,
                           'tff.utils.build_encoded_mean()',
                           encoding_utils.build_encoded_mean, data,
                           gather_encoder)
def test_run_encoded_mean(self):
  """Runs the encoded weighted mean end to end for several client weightings."""
  x = np.array([0.0, 1.0, 2.0, -1.0])
  x_spec = tf.TensorSpec(x.shape, tf.dtypes.as_dtype(x.dtype))
  x_type = computation_types.to_type(x_spec)
  gather_fn = encoding_utils.build_encoded_mean(
      x, te.encoders.as_gather_encoder(te.encoders.identity(), x_spec))
  init_state = gather_fn.initialize()

  # Signature of the wrapper: server-placed state, then client-placed values
  # and client-placed float32 weights.
  server_state_type = computation_types.FederatedType(
      gather_fn._initialize_fn.type_signature.result, placements.SERVER)
  client_value_type = computation_types.FederatedType(x_type,
                                                      placements.CLIENTS)
  client_weight_type = computation_types.FederatedType(
      computation_types.to_type(tf.float32), placements.CLIENTS)

  @computations.federated_computation(server_state_type, client_value_type,
                                      client_weight_type)
  def call_gather(state, value, weight):
    return gather_fn(state, value, weight)

  # Each case: (client values, client weights, expected weighted mean).
  cases = [
      ([x, x], [1.0, 1.0], x),
      ([x, x], [0.3, 0.7], x),
      ([x, 2 * x], [1.0, 2.0], 5 / 3 * x),
  ]
  for client_values, client_weights, expected in cases:
    _, mean = call_gather(init_state, client_values, client_weights)
    self.assertAllClose(expected, mean)
def build_encoded_mean_from_model(
    model_fn: _ModelConstructor,
    encoder_fn: _EncoderConstructor) -> computation_utils.StatefulAggregateFn:
  """Builds `StatefulAggregateFn` for weights of model returned by `model_fn`.

  This method creates a `GatherEncoder` for every trainable weight of model
  created by `model_fn`, as returned by `encoder_fn`.

  Deprecated: calling this emits a `DeprecationWarning`; use
  `tff.learning.framework.build_encoded_mean_process_from_model()` instead.

  Args:
    model_fn: A Python callable with no arguments that returns a
      `tff.learning.Model`.
    encoder_fn: A Python callable with a single argument, which is expected to
      be a `tf.Tensor` of shape and dtype to be encoded. The function must
      return a `tensor_encoding.core.SimpleEncoder`, which expects a
      `tf.Tensor` with compatible type as the input to its `encode` method.

  Returns:
    A `StatefulAggregateFn` for encoding and averaging the weights of model
    created by `model_fn`.

  Raises:
    TypeError: If `model_fn` or `encoder_fn` are not callable objects.
  """
  # stacklevel=2 attributes the warning to the caller's line, which is where
  # the deprecated API is actually used. With the default filters,
  # DeprecationWarning is only displayed when attributed to `__main__`.
  warnings.warn(
      'Deprecation warning: '
      'tff.learning.framework.build_encoded_mean_from_model() is deprecated, '
      'use tff.learning.framework.build_encoded_mean_process_from_model() '
      'instead.',
      DeprecationWarning,
      stacklevel=2)
  py_typecheck.check_callable(model_fn)
  py_typecheck.check_callable(encoder_fn)
  # Trainable weights of the model produced by `model_fn` (via the module
  # helper `_weights_from_model_fn`). One encoder is created per weight
  # tensor, keeping the same nested structure as the weights.
  trainable_weights = _weights_from_model_fn(model_fn).trainable
  encoders = tf.nest.map_structure(encoder_fn, trainable_weights)
  # NOTE(review): `encoding_utils.build_encoded_mean` may itself emit a
  # DeprecationWarning, so callers could see two warnings — confirm.
  return encoding_utils.build_encoded_mean(trainable_weights, encoders)
def test_build_encoded_mean(self, value_constructor, encoder_constructor):
  """`build_encoded_mean` yields a `StatefulAggregateFn` with expected types."""
  data = value_constructor(np.random.rand(20))
  spec = tf.TensorSpec(data.shape, tf.as_dtype(data.dtype))
  data_type = tff.to_type(spec)
  gather_encoder = te.encoders.as_gather_encoder(encoder_constructor(), spec)
  gather_fn = encoding_utils.build_encoded_mean(data, gather_encoder)
  state_type = gather_fn._initialize_fn.type_signature.result

  # Trace `next` with server-placed state and client-placed values and
  # float32 weights, to inspect the type of what it returns.
  server_state = tff.FederatedType(state_type, tff.SERVER)
  client_values = tff.FederatedType(data_type, tff.CLIENTS)
  client_weights = tff.FederatedType(tff.to_type(tf.float32), tff.CLIENTS)
  signature = tff.federated_computation(gather_fn._next_fn, server_state,
                                        client_values,
                                        client_weights).type_signature

  self.assertIsInstance(gather_fn, StatefulAggregateFn)
  # Both the updated state and the aggregated mean are placed at the server.
  state_out = signature.result[0]
  mean_out = signature.result[1]
  self.assertEqual(state_type, state_out.member)
  self.assertEqual(tff.SERVER, state_out.placement)
  self.assertEqual(data_type, mean_out.member)
  self.assertEqual(tff.SERVER, mean_out.placement)
def test_build_encoded_mean_raises_bad_structure(self):
  """A list of two values paired with a single encoder raises ValueError."""
  values = [tf.constant([0.0, 1.0]), tf.constant([0.0, 1.0])]
  single_encoder = te.encoders.as_gather_encoder(te.encoders.identity(),
                                                 tf.TensorSpec((2,)))
  self.assertRaises(ValueError, encoding_utils.build_encoded_mean, values,
                    single_encoder)
def test_build_encoded_mean_raises_incompatible_encoder(self):
  """An encoder built for shape (3,) is rejected for a shape-(2,) value."""
  data = tf.constant([0.0, 1.0])
  identity_encoder = te.encoders.identity()
  mismatched_encoder = te.encoders.as_gather_encoder(identity_encoder,
                                                     tf.TensorSpec((3,)))
  self.assertRaises(TypeError, encoding_utils.build_encoded_mean, data,
                    mismatched_encoder)
def test_build_encoded_mean_raises_bad_encoder(self, bad_encoder):
  """Each parameterized invalid `bad_encoder` makes the builder raise TypeError."""
  data = tf.constant([0.0, 1.0])
  self.assertRaises(TypeError, encoding_utils.build_encoded_mean, data,
                    bad_encoder)