Example No. 1
  def apply(
      clients_params_and_weights: Iterable[Tuple[ClientId, Params, float]],
      aggregator_state: CompressionState) -> Tuple[Params, CompressionState]:

    rng, rotation_rng = jax.random.split(aggregator_state.rng)
    rotation_rng_seq = hk.PRNGSequence(rotation_rng)
    clients_params_and_weight_rng = zip(clients_params_and_weights,
                                        rotation_rng_seq)
    def quantize_params_and_weight(client_params_and_weight, client_rng):
      _, params, weight = client_params_and_weight
      rotated_param, shapes = walsh_hadamard.structured_rotation_pytree(
          params, client_rng)
      return walsh_hadamard.inverse_structured_rotation_pytree(
          drive_pytree(rotated_param), client_rng, shapes), weight

    quantized_p_and_w = itertools.starmap(quantize_params_and_weight,
                                          clients_params_and_weight_rng)

    aggregated_params = tree_util.tree_mean(quantized_p_and_w)

    total_num_params = tree_util.tree_size(aggregated_params)
    total_num_floats = 2 * num_leaves(aggregated_params)
    # 32 bits for every float used and one bit for every parameter.
    new_bits = total_num_params + 32 * total_num_floats
    new_state = CompressionState(aggregator_state.num_bits + new_bits, rng)
    return aggregated_params, new_state
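The snippet above leans on drive_pytree for the one-bit step. As a point of
reference, here is a minimal sketch of the per-tensor sign quantizer at the
heart of DRIVE-style compression; the helper name one_bit_quantize and the
exact scale are assumptions, not the library's API:

import jax.numpy as jnp

def one_bit_quantize(x):
  # Keep only the sign of each entry plus one float scale per tensor.
  # scale = ||x||_1 / x.size is the least-squares coefficient for sign(x)
  # (assuming no exact zeros), which is why one bit per parameter plus a
  # couple of floats per leaf matches the bit accounting above.
  scale = jnp.sum(jnp.abs(x)) / x.size
  return scale * jnp.sign(x)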
Example No. 2
  def apply(
      clients_params_and_weights: Iterable[Tuple[ClientId, Params, float]],
      aggregator_state: CompressionState) -> Tuple[Params, CompressionState]:

    rng, rotation_rng = jax.random.split(aggregator_state.rng)

    def quantize_params_and_weight(client_params_and_weight, rng):
      _, params, weight = client_params_and_weight
      params, shapes = walsh_hadamard.structured_rotation_pytree(
          params, rotation_rng)
      params = walsh_hadamard.inverse_structured_rotation_pytree(
          uniform_stochastic_quantize_pytree(params, num_levels, rng),
          rotation_rng, shapes)
      return params, weight

    rng, use_rng = jax.random.split(rng)
    # TODO(theertha): remove the usage of hk.PRNGSequence.
    rng_seq = hk.PRNGSequence(use_rng)
    clients_params_and_weight_rng = zip(clients_params_and_weights, rng_seq)
    quantized_p_and_w = itertools.starmap(quantize_params_and_weight,
                                          clients_params_and_weight_rng)
    aggregated_params = tree_util.tree_mean(quantized_p_and_w)
    total_num_params = tree_util.tree_size(aggregated_params)
    total_num_floats = 2 * num_leaves(aggregated_params)
    # 32 bits for every float used and log2(num_levels) bits for every parameter.
    new_bits = math.log2(num_levels) * total_num_params + 32 * total_num_floats
    new_state = CompressionState(aggregator_state.num_bits + new_bits, rng)
    return aggregated_params, new_state
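For orientation, here is a minimal single-array sketch of the unbiased uniform
stochastic quantizer that uniform_stochastic_quantize_pytree presumably applies
leaf by leaf; the two floats per leaf in the bit accounting would be the min
and max that fix the quantization grid. Treat this as an illustration under
those assumptions, not the library's implementation:

import jax
import jax.numpy as jnp

def uniform_stochastic_quantize(x, num_levels, rng):
  # Place num_levels evenly spaced levels between min(x) and max(x), then
  # round each entry up or down at random so the result is unbiased.
  v_min, v_max = jnp.min(x), jnp.max(x)
  scale = (v_max - v_min) / (num_levels - 1)
  safe_scale = jnp.where(scale == 0, 1.0, scale)  # guard constant tensors
  normalized = (x - v_min) / safe_scale
  floor = jnp.floor(normalized)
  quantized = floor + jax.random.bernoulli(rng, normalized - floor)
  return v_min + quantized * scale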
Example No. 3
 def test_create_regression_model(self):
     model = toy_regression.create_regression_model()
     params = model.init(jax.random.PRNGKey(0))
     batch = {'x': jnp.ones((5, 1)), 'y': jnp.ones((5,))}
     self.assertEqual(tree_util.tree_size(params), 1)
     with self.subTest('apply_for_train'):
         preds = model.apply_for_train(params, batch)
         self.assertTupleEqual(preds.shape, ())
     with self.subTest('apply_for_eval'):
         preds = model.apply_for_eval(params, batch)
         self.assertTupleEqual(preds.shape, ())
     with self.subTest('train_loss'):
         preds = model.apply_for_train(params, batch)
         train_loss = model.train_loss(batch, preds)
         self.assertTupleEqual(train_loss.shape, ())
Example No. 4
  def apply(
      clients_params_and_weights: Iterable[Tuple[ClientId, Params, float]],
      aggregator_state: CompressionState) -> Tuple[Params, CompressionState]:

    if encode_algorithm is not None:
      assert encode_algorithm == 'arithmetic'

    def quantize_params_and_weight(client_params_and_weight, rng):
      _, params, weight = client_params_and_weight
      return uniform_stochastic_quantize_pytree(params, num_levels, rng), weight

    rng, use_rng = jax.random.split(aggregator_state.rng)
    # TODO(theertha): remove the usage of hk.PRNGSequence.
    rng_seq = hk.PRNGSequence(use_rng)
    clients_params_and_weight_rng = zip(clients_params_and_weights, rng_seq)
    quantized_p_and_w = itertools.starmap(quantize_params_and_weight,
                                          clients_params_and_weight_rng)
    if encode_algorithm == 'arithmetic':
      # Record each client's encoded size as the lazy iterator is consumed;
      # the average per-client cost in bits is reported after aggregation.
      total_bits = []

      def arithmetic_encoding_num_bits_pytree(params, weights):
        leaves, _ = jax.tree_util.tree_flatten(params)
        bits = sum([arithmetic_encoding_num_bits(leaf) for leaf in leaves])
        total_bits.append(bits)
        return params, weights

      quantized_p_and_w = itertools.starmap(arithmetic_encoding_num_bits_pytree,
                                            quantized_p_and_w)
      aggregated_params = tree_util.tree_mean(quantized_p_and_w)
      new_bits = sum(total_bits) / len(total_bits) if total_bits else 0.

    else:
      aggregated_params = tree_util.tree_mean(quantized_p_and_w)
      total_num_params = tree_util.tree_size(aggregated_params)
      total_num_floats = 2 * num_leaves(aggregated_params)
      # 32 bits for every float and log2(num_levels) bits for every parameter.
      new_bits = math.log2(
          num_levels) * total_num_params + 32 * total_num_floats
    new_state = CompressionState(aggregator_state.num_bits + new_bits, rng)
    return aggregated_params, new_state
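arithmetic_encoding_num_bits itself is not shown in this snippet. Since an
arithmetic coder's output length approaches the empirical entropy of the
symbol stream, a hypothetical stand-in for estimating the cost of one
quantized leaf (ignoring coder overhead) could look like this:

import numpy as np

def entropy_num_bits(leaf):
  # Empirical entropy in bits: a lower-bound proxy for the length of an
  # arithmetically coded stream of the leaf's quantized values.
  _, counts = np.unique(np.asarray(leaf), return_counts=True)
  probs = counts / counts.sum()
  return float(counts.sum() * -(probs * np.log2(probs)).sum())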
Example No. 5
  def apply(
      clients_params_and_weights: Iterable[Tuple[ClientId, Params, float]],
      aggregator_state: CompressionState) -> Tuple[Params, CompressionState]:

    def quantize_params_and_weight(client_params_and_weight, rng):
      _, params, weight = client_params_and_weight
      return terngrad_quantize_pytree(params, rng), weight

    rng, use_rng = jax.random.split(aggregator_state.rng)
    # TODO(theertha): remove the usage of hk.PRNGSequence.
    rng_seq = hk.PRNGSequence(use_rng)
    clients_params_and_weight_rng = zip(clients_params_and_weights, rng_seq)
    quantized_p_and_w = itertools.starmap(quantize_params_and_weight,
                                          clients_params_and_weight_rng)
    aggregated_params = tree_util.tree_mean(quantized_p_and_w)
    total_num_params = tree_util.tree_size(aggregated_params)
    total_num_floats = 2 * num_leaves(aggregated_params)
    # 32 bits for every float used and log2(3) bits for every parameter.
    new_bits = math.log2(3) * total_num_params + 32 * total_num_floats
    new_state = CompressionState(aggregator_state.num_bits + new_bits, rng)
    return aggregated_params, new_state
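terngrad_quantize_pytree presumably applies TernGrad's ternary quantizer leaf
by leaf; three symbols per entry is exactly the log2(3) bits per parameter in
the accounting above. A minimal single-array sketch (the pytree wrapper and
exact signature are assumptions):

import jax
import jax.numpy as jnp

def terngrad_quantize(x, rng):
  # Map each entry to {-s, 0, +s} with s = max|x|, keeping the sign with
  # probability |x| / s, so the quantizer is unbiased: E[out] = x.
  scale = jnp.max(jnp.abs(x))
  prob = jnp.abs(x) / jnp.where(scale == 0, 1.0, scale)
  keep = jax.random.bernoulli(rng, prob)
  return scale * jnp.sign(x) * keep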
Example No. 6
 def test_create_lstm_model_share_embeddings(self):
     model = stackoverflow.create_lstm_model(
         share_input_output_embeddings=True)
     params = model.init(jax.random.PRNGKey(0))
     self.assertEqual(tree_util.tree_size(params), 3090364)
     self.check_model(model)
Example No. 7
 def test_create_lstm_model(self):
     model = stackoverflow.create_lstm_model()
     params = model.init(jax.random.PRNGKey(0))
     self.assertEqual(tree_util.tree_size(params), 4050748)
     self.check_model(model)
Example No. 8
 def test_create_logistic_model(self):
     model = cifar100.create_logistic_model()
     params = model.init(jax.random.PRNGKey(0))
     self.assertEqual(tree_util.tree_size(params), 172900)
     self.check_model(model)
Example No. 9
 def test_create_stax_dense_model(self):
     model = emnist.create_stax_dense_model(only_digits=False,
                                            hidden_units=200)
     params = model.init(jax.random.PRNGKey(0))
     self.assertEqual(tree_util.tree_size(params), 209662)
     self.check_model(model)
Example No. 10
 def test_create_logistic_model(self):
     model = emnist.create_logistic_model(only_digits=False)
     params = model.init(jax.random.PRNGKey(0))
     self.assertEqual(tree_util.tree_size(params), 48670)
     self.check_model(model)