def apply(
    clients_params_and_weights: Iterable[Tuple[ClientId, Params, float]],
    aggregator_state: CompressionState) -> Tuple[Params, CompressionState]:
    """Compresses each client update with rotation + DRIVE and averages them.

    Every client's params are randomly rotated (randomized Walsh-Hadamard),
    quantized with ``drive_pytree``, rotated back with the same per-client
    rng, and the resulting pytrees are weight-averaged. The bit cost of the
    round is added to the aggregator state.

    Args:
      clients_params_and_weights: iterable of (client_id, params, weight).
      aggregator_state: carries the running bit count and the rng.

    Returns:
      A tuple of (averaged params, updated CompressionState).
    """
    rng, rotation_rng = jax.random.split(aggregator_state.rng)
    rotation_rng_seq = hk.PRNGSequence(rotation_rng)

    def compress_one(client_entry, client_rng):
        # client_entry is (client_id, params, weight); the id is unused here.
        _, params, weight = client_entry
        rotated, shapes = walsh_hadamard.structured_rotation_pytree(
            params, client_rng)
        restored = walsh_hadamard.inverse_structured_rotation_pytree(
            drive_pytree(rotated), client_rng, shapes)
        return restored, weight

    # Lazy generator: tree_mean consumes it one client at a time, so the
    # full set of client updates is never materialized at once.
    quantized_p_and_w = (
        compress_one(entry, client_rng)
        for entry, client_rng in zip(clients_params_and_weights,
                                     rotation_rng_seq))
    aggregated_params = tree_util.tree_mean(quantized_p_and_w)
    total_num_params = tree_util.tree_size(aggregated_params)
    total_num_floats = 2 * num_leaves(aggregated_params)
    # 32 bits for every float used and one bit for every parameter.
    new_bits = total_num_params + 32 * total_num_floats
    new_state = CompressionState(aggregator_state.num_bits + new_bits, rng)
    return aggregated_params, new_state
def apply(
    clients_params_and_weights: Iterable[Tuple[ClientId, Params, float]],
    aggregator_state: CompressionState) -> Tuple[Params, CompressionState]:
    """Rotates, uniformly quantizes, and averages the client updates.

    Each client's params are rotated with a shared rotation rng, quantized
    with per-client randomness via ``uniform_stochastic_quantize_pytree``,
    un-rotated, and then weight-averaged. The round's bit cost is added to
    the aggregator state.

    Args:
      clients_params_and_weights: iterable of (client_id, params, weight).
      aggregator_state: carries the running bit count and the rng.

    Returns:
      A tuple of (averaged params, updated CompressionState).
    """
    rng, rotation_rng = jax.random.split(aggregator_state.rng)

    def compress_one(client_entry, client_rng):
        # client_entry is (client_id, params, weight); the id is unused here.
        _, params, weight = client_entry
        # NOTE(review): rotation_rng is shared by all clients; only the
        # quantization rng differs per client — confirm this is intended.
        rotated, shapes = walsh_hadamard.structured_rotation_pytree(
            params, rotation_rng)
        quantized = uniform_stochastic_quantize_pytree(rotated, num_levels,
                                                       client_rng)
        restored = walsh_hadamard.inverse_structured_rotation_pytree(
            quantized, rotation_rng, shapes)
        return restored, weight

    rng, use_rng = jax.random.split(rng)
    # TODO(theertha): remove the usage of hk.PRNGSequence.
    rng_seq = hk.PRNGSequence(use_rng)
    # Lazy generator: tree_mean consumes it one client at a time.
    quantized_p_and_w = (
        compress_one(entry, client_rng)
        for entry, client_rng in zip(clients_params_and_weights, rng_seq))
    aggregated_params = tree_util.tree_mean(quantized_p_and_w)
    total_num_params = tree_util.tree_size(aggregated_params)
    total_num_floats = 2 * num_leaves(aggregated_params)
    # 32 bits for every float used and log2(num_levels) bit for every parameter.
    new_bits = math.log2(num_levels) * total_num_params + 32 * total_num_floats
    new_state = CompressionState(aggregator_state.num_bits + new_bits, rng)
    return aggregated_params, new_state
def test_create_regression_model(self):
    """Checks init, train/eval predictions, and loss of the toy model."""
    regression_model = toy_regression.create_regression_model()
    model_params = regression_model.init(jax.random.PRNGKey(0))
    fake_batch = {'x': jnp.ones((5, 1)), 'y': jnp.ones((5,))}
    self.assertEqual(tree_util.tree_size(model_params), 1)
    with self.subTest('apply_for_train'):
        train_preds = regression_model.apply_for_train(model_params,
                                                       fake_batch)
        self.assertTupleEqual(train_preds.shape, ())
    with self.subTest('apply_for_eval'):
        eval_preds = regression_model.apply_for_eval(model_params, fake_batch)
        self.assertTupleEqual(eval_preds.shape, ())
    with self.subTest('train_loss'):
        train_preds = regression_model.apply_for_train(model_params,
                                                       fake_batch)
        loss = regression_model.train_loss(fake_batch, train_preds)
        self.assertTupleEqual(loss.shape, ())
def apply(
    clients_params_and_weights: Iterable[Tuple[ClientId, Params, float]],
    aggregator_state: CompressionState) -> Tuple[Params, CompressionState]:
    """Uniformly quantizes client updates, averages them, and counts bits.

    With ``encode_algorithm == 'arithmetic'`` (the only non-None value
    allowed by the assert below) the bit cost is measured by arithmetic
    coding of the quantized leaves; otherwise a fixed-rate estimate of
    log2(num_levels) bits per parameter plus 32 bits per float is used.

    Args:
      clients_params_and_weights: iterable of (client_id, params, weight).
      aggregator_state: carries the running bit count and the rng.

    Returns:
      A tuple of (averaged params, updated CompressionState).
    """
    if encode_algorithm is not None:
        assert encode_algorithm == 'arithmetic'

    def quantize_params_and_weight(client_params_and_weight, rng):
        # client_params_and_weight is (client_id, params, weight).
        _, params, weight = client_params_and_weight
        return uniform_stochastic_quantize_pytree(params, num_levels,
                                                  rng), weight

    rng, use_rng = jax.random.split(aggregator_state.rng)
    # TODO(theertha): remove the usage of hk.PRNGSequence.
    rng_seq = hk.PRNGSequence(use_rng)
    clients_params_and_weight_rng = zip(clients_params_and_weights, rng_seq)
    # Lazy: nothing is quantized until the iterator is consumed below.
    quantized_p_and_w = itertools.starmap(quantize_params_and_weight,
                                          clients_params_and_weight_rng)
    new_bits = 0.
    if encode_algorithm == 'arithmetic':
        # Accumulate the number of bits used by all clients without loading
        # the entire iterator into memory at once.
        total_bits = []

        def arithmetic_encoding_num_bits_pytree(params, weights):
            # Identity pass-through that records the arithmetic-coded size
            # of each client's quantized leaves as a side effect.
            leaves, _ = jax.tree_util.tree_flatten(params)
            bits = sum([arithmetic_encoding_num_bits(leaf) for leaf in leaves])
            total_bits.append(bits)
            return params, weights

        quantized_p_and_w = itertools.starmap(
            arithmetic_encoding_num_bits_pytree, quantized_p_and_w)
        # tree_mean drains the starmap chain, which is what populates
        # total_bits — do not reorder these two statements.
        aggregated_params = tree_util.tree_mean(quantized_p_and_w)
        # NOTE(review): this is the *average* bits per client, not the sum,
        # despite the "accumulate" comment above — confirm which is intended.
        new_bits = sum(total_bits) / len(total_bits) if len(total_bits) else 0.
    else:
        aggregated_params = tree_util.tree_mean(quantized_p_and_w)
        total_num_params = tree_util.tree_size(aggregated_params)
        total_num_floats = 2 * num_leaves(aggregated_params)
        # 32 bits for every float and log2(num_levels) bit for every parameter.
        new_bits = math.log2(
            num_levels) * total_num_params + 32 * total_num_floats
    new_state = CompressionState(aggregator_state.num_bits + new_bits, rng)
    return aggregated_params, new_state
def apply(
    clients_params_and_weights: Iterable[Tuple[ClientId, Params, float]],
    aggregator_state: CompressionState) -> Tuple[Params, CompressionState]:
    """TernGrad-quantizes each client update and averages the results.

    Each client's params are ternarized with ``terngrad_quantize_pytree``
    using a per-client rng, then weight-averaged. The round's bit cost
    (log2(3) bits per parameter plus 32 bits per float) is added to the
    aggregator state.

    Args:
      clients_params_and_weights: iterable of (client_id, params, weight).
      aggregator_state: carries the running bit count and the rng.

    Returns:
      A tuple of (averaged params, updated CompressionState).
    """

    def quantize_one(client_entry, client_rng):
        # client_entry is (client_id, params, weight); the id is unused here.
        _, params, weight = client_entry
        return terngrad_quantize_pytree(params, client_rng), weight

    rng, use_rng = jax.random.split(aggregator_state.rng)
    # TODO(theertha): remove the usage of hk.PRNGSequence.
    rng_seq = hk.PRNGSequence(use_rng)
    # Lazy generator: tree_mean consumes it one client at a time.
    quantized_p_and_w = (
        quantize_one(entry, client_rng)
        for entry, client_rng in zip(clients_params_and_weights, rng_seq))
    aggregated_params = tree_util.tree_mean(quantized_p_and_w)
    total_num_params = tree_util.tree_size(aggregated_params)
    total_num_floats = 2 * num_leaves(aggregated_params)
    # 32 bits for every float used and log2(3) bit for every parameter.
    new_bits = math.log2(3) * total_num_params + 32 * total_num_floats
    new_state = CompressionState(aggregator_state.num_bits + new_bits, rng)
    return aggregated_params, new_state
def test_create_lstm_model_share_embeddings(self):
    """Checks the param count of the LSTM with tied input/output embeddings."""
    tied_model = stackoverflow.create_lstm_model(
        share_input_output_embeddings=True)
    init_params = tied_model.init(jax.random.PRNGKey(0))
    self.assertEqual(tree_util.tree_size(init_params), 3090364)
    self.check_model(tied_model)
def test_create_lstm_model(self):
    """Checks the param count of the default stackoverflow LSTM model."""
    lstm_model = stackoverflow.create_lstm_model()
    init_params = lstm_model.init(jax.random.PRNGKey(0))
    self.assertEqual(tree_util.tree_size(init_params), 4050748)
    self.check_model(lstm_model)
def test_create_logistic_model(self):
    """Checks the param count of the cifar100 logistic model."""
    logistic_model = cifar100.create_logistic_model()
    init_params = logistic_model.init(jax.random.PRNGKey(0))
    self.assertEqual(tree_util.tree_size(init_params), 172900)
    self.check_model(logistic_model)
def test_create_stax_dense_model(self):
    """Checks the param count of the 62-class stax dense EMNIST model."""
    dense_model = emnist.create_stax_dense_model(only_digits=False,
                                                 hidden_units=200)
    init_params = dense_model.init(jax.random.PRNGKey(0))
    self.assertEqual(tree_util.tree_size(init_params), 209662)
    self.check_model(dense_model)
def test_create_logistic_model(self):
    """Checks the param count of the 62-class EMNIST logistic model."""
    logistic_model = emnist.create_logistic_model(only_digits=False)
    init_params = logistic_model.init(jax.random.PRNGKey(0))
    self.assertEqual(tree_util.tree_size(init_params), 48670)
    self.check_model(logistic_model)