Beispiel #1
0
 def test_clip_by_global_norm(self):
   """Repeatedly clipping to a shrinking max norm stays on the bound, stably."""
   clipped = self.per_step_updates
   for step in range(1, STEPS + 1):
     max_norm = 1. / step
     clipper = transform.clip_by_global_norm(max_norm)
     # Clip the already-clipped updates from the previous step again; the
     # resulting global norm is expected to land on the new (smaller) bound.
     # NOTE(review): equality assumes per_step_updates has global norm >= 1.
     clipped, _ = clipper.update(clipped, None)
     self.assertAlmostEqual(transform.global_norm(clipped), max_norm, places=6)
     # A single clip of the raw updates must match the chain of clips above,
     # i.e. clipping over and over introduces no numerical drift.
     fresh, _ = clipper.update(self.per_step_updates, None)
     chex.assert_tree_all_close(clipped, fresh, atol=1e-7, rtol=1e-7)
Beispiel #2
0
    def test_scale_by_fromage(self):
        """Fromage update norms strictly decrease under a decaying schedule."""

        def schedule(c):
            # Step-size factor 1 / (c + 1), which decreases as c increases.
            return 1.0 / (c + 1.0)

        fromage = transform.scale_by_fromage(step_size_factor_fn=schedule)
        params = self.init_params
        opt_state = fromage.init(params)
        last_norm = jnp.inf
        update_fn = self.variant(fromage.update)

        for _ in range(STEPS):
            # One fromage step on the same raw updates each iteration.
            updates, opt_state = update_fn(
                self.per_step_updates, opt_state, params)
            # The decaying step-size factor should make every successive
            # update strictly smaller in global norm than the previous one.
            current_norm = transform.global_norm(updates)
            self.assertLess(current_norm, last_norm)
            last_norm = current_norm