示例#1
0
def _mean_encoder_fn(value):
    """Builds the gather encoder used for the encoded mean of `value`.

    The choice of encoder depends on the tensor size. Tensors with more than
    10000 elements are wrapped with a quantizing (lossy) encoder; smaller
    tensors are sent unchanged through the identity encoder, because
    compressing small weights saves very little while lossy compression of
    small weights tends to hurt model accuracy more.

    Args:
        value: A tensor or variable to be encoded in client to server
            communication. Only its `shape` and `dtype` are read here.

    Returns:
        A `te.core.GatherEncoder`.
    """
    # TODO(b/131681951): We cannot use .from_tensor(...) because it does not
    # currently support Variables.
    spec = tf.TensorSpec(value.shape, value.dtype)

    # Small tensors are not worth compressing: keep them as is.
    if value.shape.num_elements() <= 10000:
        return te.encoders.as_gather_encoder(te.encoders.identity(), spec)

    # Large tensors: the flag selects between the sparse quantizing encoder
    # and plain uniform quantization, both configured with the same bit width.
    if FLAGS.use_sparsity_in_aggregation:
        inner_encoder = sparsity.sparse_quantizing_encoder(
            FLAGS.aggregation_quantization_bits)
    else:
        inner_encoder = te.encoders.uniform_quantization(
            FLAGS.aggregation_quantization_bits)
    return te.encoders.as_gather_encoder(inner_encoder, spec)
示例#2
0
def _mean_encoder_fn(spec):
  """Builds the gather encoder used for the encoded mean of a value.

  The choice of encoder depends on the size of the described tensor. Specs
  with more than 10000 elements are wrapped with a quantizing (lossy) encoder;
  smaller ones use the identity encoder, because compressing small weights
  saves very little while lossy compression of small weights tends to hurt
  model accuracy more.

  Args:
    spec: A `tf.TensorSpec` for the value to be encoded in client to server
      communication.

  Returns:
    A `te.core.GatherEncoder`.
  """
  # Small tensors are not worth compressing: keep them as is.
  if spec.shape.num_elements() <= 10000:
    return te.encoders.as_gather_encoder(te.encoders.identity(), spec)

  # Large tensors: the flag selects between the sparse quantizing encoder and
  # plain uniform quantization, both configured with the same bit width.
  if FLAGS.use_sparsity_in_aggregation:
    inner_encoder = sparsity.sparse_quantizing_encoder(
        FLAGS.aggregation_quantization_bits)
  else:
    inner_encoder = te.encoders.uniform_quantization(
        FLAGS.aggregation_quantization_bits)
  return te.encoders.as_gather_encoder(inner_encoder, spec)
示例#3
0
 def test_sparse_quantizing_encoder(self):
     """Checks that `sparse_quantizing_encoder(8)` is a `te.core.Encoder`."""
     self.assertIsInstance(
         sparsity.sparse_quantizing_encoder(8), te.core.Encoder)