def test_build_encoded_broadcast_process_raises_incompatible_encoder(self):
  """Rejects an encoder whose spec shape (3,) differs from the value's [2]."""
  float_pair_type = computation_types.TensorType(tf.float32, shape=[2])
  mismatched_encoder = te.encoders.as_simple_encoder(
      te.encoders.identity(), tf.TensorSpec((3,)))
  # Callable form of assertRaises: the builder is invoked by the assertion.
  self.assertRaises(TypeError,
                    encoding_utils.build_encoded_broadcast_process,
                    float_pair_type, mismatched_encoder)
def test_build_encoded_broadcast_process_raises_bad_structure(self):
  """Rejects a single encoder paired with a two-member struct value type."""
  # Two independent float32 [2] tensor types, as members of one struct.
  member_types = [
      computation_types.TensorType(tf.float32, shape=[2]) for _ in range(2)
  ]
  struct_type = computation_types.StructType(member_types)
  single_encoder = te.encoders.as_simple_encoder(te.encoders.identity(),
                                                 tf.TensorSpec((2,)))
  self.assertRaises(ValueError,
                    encoding_utils.build_encoded_broadcast_process,
                    struct_type, single_encoder)
def build_encoded_broadcast_process_from_model(
    model_fn: _ModelConstructor,
    encoder_fn: _EncoderConstructor) -> measured_process.MeasuredProcess:
  """Builds a `MeasuredProcess` broadcasting the weights of `model_fn`'s model.

  The weights of the model created by `model_fn` are obtained, and
  `encoder_fn` is applied to every weight tensor in that structure. The result
  is one `SimpleEncoder` per weight.

  Args:
    model_fn: A Python callable that takes no arguments and returns a
      `tff.learning.Model`.
    encoder_fn: A Python callable with a single argument, which is expected to
      be a `tf.Tensor` of the shape and dtype to be encoded. The function must
      return a `tensor_encoding.core.SimpleEncoder` whose `encode` method
      accepts a `tf.Tensor` of a compatible type.

  Returns:
    A `MeasuredProcess` for encoding and broadcasting the weights of the model
    created by `model_fn`.

  Raises:
    TypeError: If `model_fn` or `encoder_fn` are not callable objects.
  """
  for constructor in (model_fn, encoder_fn):
    py_typecheck.check_callable(constructor)

  model_weights = _weights_from_model_fn(model_fn)
  # Build the encoders before deriving the type, matching the original order.
  weight_encoders = tf.nest.map_structure(encoder_fn, model_weights)
  return encoding_utils.build_encoded_broadcast_process(
      type_conversions.type_from_tensors(model_weights), weight_encoders)
def test_build_encoded_broadcast_process(self, value_constructor,
                                         encoder_constructor):
  """Checks the type signatures of an encoded broadcast process.

  Args:
    value_constructor: Callable applied to `np.random.rand(20)`; its result
      must expose `.shape` and `.dtype` for building the `tf.TensorSpec`.
    encoder_constructor: No-argument callable returning the encoder that is
      wrapped with `te.encoders.as_simple_encoder`.
  """
  value = value_constructor(np.random.rand(20))
  value_spec = tf.TensorSpec(value.shape, tf.dtypes.as_dtype(value.dtype))
  value_type = computation_types.to_type(value_spec)
  encoder = te.encoders.as_simple_encoder(encoder_constructor(), value_spec)

  broadcast_process = encoding_utils.build_encoded_broadcast_process(
      value_type, encoder)
  # Assert the type before touching its attributes, so a wrong return type
  # fails with a clear assertion message instead of an AttributeError.
  self.assertIsInstance(broadcast_process, MeasuredProcess)

  # Use the public `initialize` / `next` computations rather than the private
  # `_initialize_fn` / `_next_fn` attributes.
  state_type = broadcast_process.initialize.type_signature.result
  broadcast_signature = broadcast_process.next.type_signature

  # Element 0 of the `next` result is the server-placed state.
  self.assertEqual(state_type, broadcast_signature.result[0])
  self.assertEqual(placements.SERVER, broadcast_signature.result[0].placement)
  # Element 1 is the broadcast value, placed at the clients.
  self.assertEqual(value_type, broadcast_signature.result[1].member)
  self.assertEqual(placements.CLIENTS,
                   broadcast_signature.result[1].placement)
def test_build_encoded_broadcast_process_raises_bad_encoder(
    self, bad_encoder):
  """Passing `bad_encoder` for a float32 [2] value must raise TypeError."""
  # The value type is built as an argument, i.e. before assertRaises runs,
  # so only the builder call itself is expected to raise.
  self.assertRaises(
      TypeError,
      encoding_utils.build_encoded_broadcast_process,
      computation_types.TensorType(tf.float32, shape=[2]),
      bad_encoder)