def test_safe_div(self):
  # Division by zero returns 0 instead of inf/nan.
  npt.assert_array_equal(
      util.safe_div(jnp.array([1, 2, 3]), jnp.array([0, 1, 2])),
      [0, 2, 1.5])
  # The gradient is also safe (zero) wherever the denominator is zero.
  grad = jax.grad(lambda xy: jnp.sum(util.safe_div(xy[0], xy[1])))
  npt.assert_array_equal(
      grad(jnp.array([[1, 2, 3], [0, 1, 2]], dtype=jnp.float32)),
      [[0, 1, 0.5], [0, -2, -3 / 4]])
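# For reference, a minimal sketch of a safe_div consistent with the test
# above (an assumption; the real util.safe_div may be implemented
# differently). The "double where" trick substitutes a dummy denominator of
# 1 wherever y == 0, so neither the forward value nor the gradient produces
# nan: a plain jnp.where(y == 0, 0, x / y) would still leak nan through the
# gradient of the discarded branch.
def _sketch_safe_div(x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
  safe_y = jnp.where(y == 0, jnp.ones_like(y), y)
  return jnp.where(y == 0, 0, x / safe_y)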
def scalar_loss(params, batch_example, rng):
  batch_loss = per_example_loss(params, batch_example, rng)
  if client_datasets.EXAMPLE_MASK_KEY in batch_example:
    # Masked mean: padded examples (mask == 0) are excluded, and safe_div
    # guards against an all-padding batch (num_examples == 0).
    mask = batch_example[client_datasets.EXAMPLE_MASK_KEY]
    num_examples = jnp.sum(mask)
    loss = util.safe_div(jnp.vdot(batch_loss, mask), num_examples)
  else:
    loss = jnp.mean(batch_loss)
  if regularizer is not None:
    loss += regularizer(params)
  return loss
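# A small illustration of the masked branch above, with made-up values:
# the padded example contributes nothing to the mean, and an all-padding
# batch would yield 0 rather than nan thanks to safe_div.
def _example_masked_loss():
  batch_loss = jnp.array([0.5, 1.5, 9.9])  # 9.9 comes from a padding example.
  mask = jnp.array([1.0, 1.0, 0.0])
  # == 1.0, the mean over the two real examples only.
  return util.safe_div(jnp.vdot(batch_loss, mask), jnp.sum(mask))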
def server_update(server_state, mean_delta_params, sum_domain_loss, sum_domain_num):
  opt_state, params = server_optimizer.apply(mean_delta_params,
                                             server_state.opt_state,
                                             server_state.params)
  # safe_div leaves the mean loss at 0 for domains that saw no examples
  # this round, so they do not perturb the domain weight update.
  mean_domain_loss = util.safe_div(sum_domain_loss, sum_domain_num)
  domain_weights = update_domain_weights(server_state.domain_weights,
                                         mean_domain_loss,
                                         domain_learning_rate,
                                         domain_algorithm)
  # Slide the per-round example-count window forward by one round.
  domain_window = server_state.domain_window[1:] + [sum_domain_num]
  return ServerState(params, opt_state, domain_weights, domain_window)
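# One plausible instantiation of update_domain_weights, assuming an
# exponentiated-gradient ('eg') domain_algorithm in the style of agnostic
# federated learning: weights grow multiplicatively on high-loss domains
# and are renormalized back onto the simplex. This is a sketch under that
# assumption, not necessarily the update used here.
def _sketch_eg_update(domain_weights, mean_domain_loss, learning_rate):
  new_weights = domain_weights * jnp.exp(learning_rate * mean_domain_loss)
  return new_weights / jnp.sum(new_weights)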
def result(self) -> jnp.ndarray:
  return util.safe_div(self.accum, self.weight)
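# For context, a minimal sketch of the weighted-mean accumulator this
# result() belongs to. The accum/weight fields come from the snippet; the
# class shape and merge method are assumptions. safe_div makes result()
# return 0 instead of nan when nothing has been merged yet.
import dataclasses


@dataclasses.dataclass
class _SketchMeanStat:
  accum: jnp.ndarray   # Running weighted sum of observed values.
  weight: jnp.ndarray  # Running total weight; 0 until something is merged.

  def merge(self, value: jnp.ndarray, weight: jnp.ndarray) -> '_SketchMeanStat':
    return _SketchMeanStat(self.accum + value * weight, self.weight + weight)

  def result(self) -> jnp.ndarray:
    return util.safe_div(self.accum, self.weight)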
def _finalize_average_loss(regularizer, params, accum_loss, num_examples):
  # safe_div returns 0 when num_examples == 0, so an empty dataset yields
  # just the regularizer term rather than nan.
  average_loss = util.safe_div(accum_loss, num_examples)
  if regularizer is not None:
    average_loss += regularizer(params)
  return average_loss
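# Hypothetical usage with made-up values: on an empty dataset
# (accum_loss == 0, num_examples == 0) the finalized loss is exactly the
# regularizer term instead of nan.
def _example_empty_dataset_loss():
  params = {'w': jnp.array([1.0, -2.0])}
  l2 = lambda p: 0.01 * jnp.vdot(p['w'], p['w'])
  # Returns 0.05: zero data loss plus 0.01 * (1 + 4) from the regularizer.
  return _finalize_average_loss(l2, params, jnp.array(0.0), jnp.array(0.0))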