Exemplo n.º 1
0
 def test_safe_div(self):
   """util.safe_div yields 0 (not inf/nan) for x / 0, in value and gradient."""
   # Forward pass: the zero denominator maps to 0, the others divide normally.
   numerators = jnp.array([1, 2, 3])
   denominators = jnp.array([0, 1, 2])
   quotient = util.safe_div(numerators, denominators)
   npt.assert_array_equal(quotient, [0, 2, 1.5])

   # Backward pass: row 0 holds d/dx = 1/y and row 1 holds d/dy = -x/y**2.
   # Both must be 0 where the denominator is zero.
   def summed_quotient(xy):
     return jnp.sum(util.safe_div(xy[0], xy[1]))

   stacked = jnp.array([[1, 2, 3], [0, 1, 2]], dtype=jnp.float32)
   expected_grad = [[0, 1, 0.5], [0, -2, -3 / 4]]
   npt.assert_array_equal(jax.grad(summed_quotient)(stacked), expected_grad)
Exemplo n.º 2
0
 def scalar_loss(params, batch_example, rng):
     """Reduces the per-example losses of one batch to a single scalar.

     When the batch carries an entry under `client_datasets.EXAMPLE_MASK_KEY`,
     the loss is the mask-weighted sum of per-example losses divided by the
     mask sum (via `util.safe_div`). Otherwise it is the plain mean. If the
     enclosing scope's `regularizer` is set, `regularizer(params)` is added.

     Args:
       params: Model parameters, forwarded to `per_example_loss` and
         `regularizer`.
       batch_example: Mapping holding the batch. It may contain the mask key.
       rng: Random key forwarded to `per_example_loss`.

     Returns:
       The scalar loss.
     """
     losses = per_example_loss(params, batch_example, rng)
     mask_key = client_datasets.EXAMPLE_MASK_KEY
     if mask_key not in batch_example:
         loss = jnp.mean(losses)
     else:
         # NOTE(review): presumably the mask is 1 for real examples and 0 for
         # padding, so this averages over real examples only. Confirm with
         # the client_datasets docs.
         mask = batch_example[mask_key]
         total_weight = jnp.sum(mask)
         loss = util.safe_div(jnp.vdot(losses, mask), total_weight)
     if regularizer is not None:
         loss += regularizer(params)
     return loss
Exemplo n.º 3
0
 def server_update(server_state, mean_delta_params, sum_domain_loss,
                   sum_domain_num):
     """Builds the next `ServerState` from one round of aggregated results.

     The function does four things:
     - steps `server_optimizer` with the mean parameter delta;
     - recomputes domain weights from the per-domain mean loss
       (`sum_domain_loss / sum_domain_num` via `util.safe_div`);
     - drops the oldest entry of `domain_window`;
     - appends `sum_domain_num` to `domain_window`.

     Args:
       server_state: Current state, read for `params`, `opt_state`,
         `domain_weights` and `domain_window`.
       mean_delta_params: Mean parameter update passed to the optimizer.
       sum_domain_loss: Accumulated loss per domain.
       sum_domain_num: Accumulated example count per domain.

     Returns:
       A new `ServerState(params, opt_state, domain_weights, domain_window)`.
     """
     new_opt_state, new_params = server_optimizer.apply(
         mean_delta_params, server_state.opt_state, server_state.params)
     # NOTE(review): safe_div presumably guards domains that saw no examples
     # this round (a zero count gives a mean loss of 0). Confirm intent.
     per_domain_mean_loss = util.safe_div(sum_domain_loss, sum_domain_num)
     new_domain_weights = update_domain_weights(
         server_state.domain_weights, per_domain_mean_loss,
         domain_learning_rate, domain_algorithm)
     # Sliding window: drop the oldest entry and append this round's counts.
     # This assumes domain_window is a list (TODO confirm).
     new_domain_window = server_state.domain_window[1:] + [sum_domain_num]
     return ServerState(new_params, new_opt_state, new_domain_weights,
                        new_domain_window)
Exemplo n.º 4
0
 def result(self) -> jnp.ndarray:
     """Returns `self.accum / self.weight`, or 0 wherever the weight is 0.

     The zero-weight behavior comes from `util.safe_div`.
     """
     numerator, denominator = self.accum, self.weight
     return util.safe_div(numerator, denominator)
Exemplo n.º 5
0
def _finalize_average_loss(regularizer, params, accum_loss, num_examples):
    """Converts an accumulated loss sum into an average, plus regularization.

    Args:
      regularizer: Optional callable on `params`. When not None, its result
        is added to the average.
      params: Parameters passed to `regularizer`.
      accum_loss: Summed loss over the examples seen.
      num_examples: Number of examples the sum was taken over. A value of 0
        yields an average of 0 via `util.safe_div` instead of dividing by zero.

    Returns:
      `accum_loss / num_examples`, plus `regularizer(params)` if a regularizer
      is given.
    """
    mean_loss = util.safe_div(accum_loss, num_examples)
    if regularizer is None:
        return mean_loss
    mean_loss += regularizer(params)
    return mean_loss