コード例 #1
0
 def test_build_encoded_mean_process_raises_incompatible_encoder(self):
     """Expects `TypeError` when the encoder spec and value type disagree.

     The gather encoder is created for a `tf.TensorSpec` of shape `(3,)`,
     while the value type handed to the builder is a float32 tensor of
     shape `[2]`.
     """
     float_pair_type = computation_types.TensorType(tf.float32, shape=[2])
     identity_for_three = te.encoders.as_gather_encoder(
         te.encoders.identity(), tf.TensorSpec((3, )))
     # Callable form of assertRaises: the builder is invoked with the
     # trailing arguments and the call must raise TypeError.
     self.assertRaises(TypeError,
                       encoding_utils.build_encoded_mean_process,
                       float_pair_type, identity_for_three)
コード例 #2
0
 def test_build_encoded_mean_process_raises_bad_structure(self):
     """Expects `ValueError` for a struct value type with a lone encoder.

     The value type is a struct of two float32 tensors of shape `[2]`, but
     only a single gather encoder (built for shape `(2,)`) is supplied
     instead of a structure of encoders.
     """
     member_types = [
         computation_types.TensorType(tf.float32, shape=[2])
         for _ in range(2)
     ]
     two_tensor_type = computation_types.StructType(member_types)
     lone_encoder = te.encoders.as_gather_encoder(
         te.encoders.identity(), tf.TensorSpec((2, )))
     # Callable form of assertRaises: the builder call itself must raise.
     self.assertRaises(ValueError,
                       encoding_utils.build_encoded_mean_process,
                       two_tensor_type, lone_encoder)
コード例 #3
0
def build_encoded_mean_process_from_model(
        model_fn: _ModelConstructor,
        encoder_fn: _EncoderConstructor) -> measured_process.MeasuredProcess:
    """Builds an encoded-mean `MeasuredProcess` for a model's trainable weights.

    The trainable weights of the model produced by `model_fn` are looked up,
    `encoder_fn` is applied to each of them to obtain a matching structure of
    encoders, and both the weights' type and the encoders are forwarded to
    `encoding_utils.build_encoded_mean_process`.

    Args:
      model_fn: A no-argument Python callable that returns a
        `tff.learning.Model`.
      encoder_fn: A single-argument Python callable. Its argument is expected
        to be a `tf.Tensor` of the shape and dtype to be encoded, and it must
        return a `tensor_encoding.core.SimpleEncoder` whose `encode` method
        accepts a `tf.Tensor` of compatible type.

    Returns:
      A `MeasuredProcess` for encoding and averaging the weights of the model
      created by `model_fn`.

    Raises:
      TypeError: If `model_fn` or `encoder_fn` are not callable objects.
    """
    # NOTE(review): the original documentation mentions both `GatherEncoder`
    # (as what gets created per weight) and `SimpleEncoder` (as the return
    # type of `encoder_fn`) -- confirm which encoder type is actually required.
    for constructor in (model_fn, encoder_fn):
        py_typecheck.check_callable(constructor)

    model_weights = _weights_from_model_fn(model_fn)
    trainable = model_weights.trainable
    # One encoder per trainable tensor, mirroring the weights' nesting.
    per_weight_encoders = tf.nest.map_structure(encoder_fn, trainable)
    trainable_type = type_conversions.type_from_tensors(trainable)
    return encoding_utils.build_encoded_mean_process(trainable_type,
                                                     per_weight_encoders)
コード例 #4
0
  def test_build_encoded_mean_process(self, value_constructor,
                                      encoder_constructor):
    """Checks the type signature of the process built for one tensor value.

    Verifies that the built object is a `MeasuredProcess`, that the first
    element of the `next` result matches the `initialize` result type and is
    placed at the server, and that the second element wraps the value type
    and is also placed at the server.
    """
    sample = value_constructor(np.random.rand(20))
    sample_spec = tf.TensorSpec(sample.shape,
                                tf.dtypes.as_dtype(sample.dtype))
    expected_member_type = computation_types.to_type(sample_spec)
    process = encoding_utils.build_encoded_mean_process(
        expected_member_type,
        te.encoders.as_gather_encoder(encoder_constructor(), sample_spec))
    initial_state_type = process._initialize_fn.type_signature.result
    next_signature = process._next_fn.type_signature

    self.assertIsInstance(process, MeasuredProcess)

    # First output of `next`: the updated state.
    state_output = next_signature.result[0]
    self.assertEqual(initial_state_type, state_output)
    self.assertEqual(placements.SERVER, state_output.placement)

    # Second output of `next`: the aggregated value.
    value_output = next_signature.result[1]
    self.assertEqual(expected_member_type, value_output.member)
    self.assertEqual(placements.SERVER, value_output.placement)
コード例 #5
0
  def test_run_encoded_mean_process(self):
    """Runs the encoded mean process with an identity encoder.

    Each case feeds two client values with per-client weights and compares
    the 'result' entry of the output against the expected weighted mean.
    """
    base = np.array([0.0, 1.0, 2.0, -1.0])
    base_spec = tf.TensorSpec(base.shape, tf.dtypes.as_dtype(base.dtype))
    identity_encoder = te.encoders.as_gather_encoder(te.encoders.identity(),
                                                     base_spec)
    process = encoding_utils.build_encoded_mean_process(
        type_conversions.type_from_tensors(base), identity_encoder)
    state = process.initialize()
    run_next = process._next_fn

    # (client values, client weights, expected weighted mean)
    cases = (
        ([base, base], [1.0, 1.0], 1 * base),
        ([base, base], [0.3, 0.7], 1 * base),
        ([base, 2 * base], [1.0, 2.0], 5 / 3 * base),
    )
    # Every case starts from the same initial state, as the state returned
    # by a previous call is never fed back in.
    for client_values, client_weights, expected_mean in cases:
      outcome = run_next(state, client_values, client_weights)
      self.assertAllClose(expected_mean, outcome['result'])
コード例 #6
0
 def test_build_encoded_mean_process_raises_bad_encoder(self, bad_encoder):
     """Expects `TypeError` when `bad_encoder` is passed as the encoder.

     `bad_encoder` is supplied by the caller of this test method; the value
     type is a float32 tensor of shape `[2]`.
     """
     float_pair_type = computation_types.TensorType(tf.float32, shape=[2])
     # Callable form of assertRaises: the builder call itself must raise.
     self.assertRaises(TypeError,
                       encoding_utils.build_encoded_mean_process,
                       float_pair_type, bad_encoder)