def test_optstate_sumsq(self):
  """Test that optstate sumsq and sum are computed correctly."""
  init_fn, update_fn, _ = make_training_metrics(
      self.num_train_steps,
      optstate_sumsq_fields=['nu'],
      optstate_sum_fields=['nu'])
  initial_metrics_state = init_fn(self.mock_params0)
  self.assertTrue(
      pytree_equal(initial_metrics_state['optstate_sumsq'],
                   {'nu': jnp.zeros(self.num_train_steps)}))
  self.assertTrue(
      pytree_equal(initial_metrics_state['optstate_sum'],
                   {'nu': jnp.zeros(self.num_train_steps)}))

  updated_metrics_state = update_fn(initial_metrics_state, 0,
                                    self.mock_cost0, self.mock_grad1,
                                    self.mock_params0, self.mock_params1,
                                    self.mock_optimizer_state0)
  updated_metrics_state = update_fn(updated_metrics_state, 1,
                                    self.mock_cost1, self.mock_grad2,
                                    self.mock_params1, self.mock_params2,
                                    self.mock_optimizer_state1)

  self.assertEqual(updated_metrics_state['optstate_sumsq']['nu'][0],
                   total_tree_norm_sql2(self.mock_nu0))
  self.assertEqual(updated_metrics_state['optstate_sumsq']['nu'][1],
                   total_tree_norm_sql2(self.mock_nu1))
  self.assertEqual(updated_metrics_state['optstate_sum']['nu'][0],
                   total_tree_sum(self.mock_nu0))
  self.assertEqual(updated_metrics_state['optstate_sum']['nu'][1],
                   total_tree_sum(self.mock_nu1))
def test_update_param_norm(self):
  """Ensure that the training metrics updater updates param norm correctly."""
  init_fn, update_fn, _ = make_training_metrics(self.num_train_steps)
  initial_metrics_state = init_fn(self.mock_params0)
  updated_metrics_state = update_fn(initial_metrics_state, 0,
                                    self.mock_cost0, self.mock_grad1,
                                    self.mock_params0, self.mock_params1,
                                    self.mock_optimizer_state0)
  self.assertTrue(
      pytree_equal(
          updated_metrics_state['param_norm'], {
              'foo': jnp.linalg.norm(self.mock_params0['foo']),
              'bar': {
                  'baz': jnp.linalg.norm(self.mock_params0['bar']['baz'])
              }
          }))

  updated_metrics_state = update_fn(initial_metrics_state, 1,
                                    self.mock_cost1, self.mock_grad2,
                                    self.mock_params1, self.mock_params2,
                                    self.mock_optimizer_state1)
  self.assertTrue(
      pytree_equal(
          updated_metrics_state['param_norm'], {
              'foo': jnp.linalg.norm(self.mock_params1['foo']),
              'bar': {
                  'baz': jnp.linalg.norm(self.mock_params1['bar']['baz'])
              }
          }))
def test_summarize(self):
  """Test the training metrics summarizer."""
  _, _, summarize_fn = make_training_metrics(
      self.num_train_steps, enable_train_cost=True, enable_ema=True)
  metrics_state = {
      'train_cost': jnp.array([1.0, 0.5, 0.25, 0.0, 0.0]),
      'param_norm': {
          'foo': 7.0,
          'bar': {
              'baz': 2.0
          }
      },
      'grad_ema': {
          'foo': 1 * jnp.ones(5),
          'bar': {
              'baz': 2 * jnp.ones(10)
          }
      },
      'grad_sq_ema': {
          'foo': 2 * jnp.ones(5),
          'bar': {
              'baz': 6 * jnp.ones(10)
          }
      },
      'update_ema': {
          'foo': 2 * jnp.ones(5),
          'bar': {
              'baz': 1 * jnp.ones(10)
          }
      },
      'update_sq_ema': {
          'foo': 6 * jnp.ones(5),
          'bar': {
              'baz': 2 * jnp.ones(10)
          }
      },
  }
  tree_summary = summarize_fn(metrics_state)
  self.assertTrue(
      pytree_equal(
          tree_summary, {
              'param_norm': {
                  '/foo': 7.0,
                  '/bar/baz': 2.0
              },
              'grad_var': {
                  '/foo': 5 * (2 - 1**2),
                  '/bar/baz': 10 * (6 - 2**2)
              },
              'update_var': {
                  '/foo': 5 * (6 - 2**2),
                  '/bar/baz': 10 * (2 - 1**2)
              },
              'update_ratio': {
                  '/foo': 5 * (6 - 2**2) / 7.0,
                  '/bar/baz': 10 * (2 - 1**2) / 2.0
              }
          }))
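# Worked arithmetic behind the expected summary in test_summarize. This is a
# sketch of the aggregation inferred from the asserted values (an assumption
# about summarize_fn, not a copy of its implementation): each variance entry
# sums the per-element estimate E[g^2] - (E[g])^2 over the leaf, and
# update_ratio divides update_var by the leaf's param_norm.
#
#   grad_var['/foo']     = 5  * (grad_sq_ema - grad_ema**2)     = 5  * (2 - 1**2)
#   grad_var['/bar/baz']  = 10 * (6 - 2**2)
#   update_var['/foo']    = 5  * (update_sq_ema - update_ema**2) = 5 * (6 - 2**2)
#   update_ratio['/foo']  = update_var['/foo'] / param_norm['/foo']  # / 7.0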
def test_init(self):
  """Test the training metrics initializer."""
  zeros_like_params = jax.tree_map(jnp.zeros_like, self.mock_params0)
  zeros_scalar_like_params = jax.tree_map(lambda x: 0.0, self.mock_params0)
  zeros_timeseries = jnp.zeros(self.num_train_steps)
  zeros_timeseries_like_params = jax.tree_map(
      lambda x: jnp.zeros(self.num_train_steps), self.mock_params0)

  # Test init with everything disabled.
  init_fn, _, _ = make_training_metrics(self.num_train_steps)
  initial_metrics_state = init_fn(self.mock_params0)
  self.assertTrue(
      pytree_equal({'param_norm': zeros_scalar_like_params},
                   initial_metrics_state))

  # Test init with all optional metrics enabled.
  init_fn, _, _ = make_training_metrics(
      self.num_train_steps,
      enable_ema=True,
      enable_train_cost=True,
      enable_param_norms=True,
      enable_gradient_norm=True,
      enable_update_norm=True,
      enable_update_norms=True)
  initial_metrics_state = init_fn(self.mock_params0)
  self.assertTrue(
      pytree_equal(
          initial_metrics_state, {
              'train_cost': zeros_timeseries,
              'param_norm': zeros_scalar_like_params,
              'grad_ema': zeros_like_params,
              'grad_sq_ema': zeros_like_params,
              'update_ema': zeros_like_params,
              'update_sq_ema': zeros_like_params,
              'param_norms': zeros_timeseries_like_params,
              'gradient_norm': zeros_timeseries,
              'update_norm': zeros_timeseries,
              'update_norms': zeros_timeseries_like_params
          }))
def test_save_load_roundtrip(self):
  """Test that saving and loading produces the original state."""
  baz = ['a', 'b', 'ccc']
  state = dict(
      params=self.params, global_step=5, completed_epochs=4, baz=baz)
  checkpoint.save_checkpoint(self.test_dir, 0, state)
  latest = checkpoint.load_latest_checkpoint(self.test_dir, target=state)

  self.assertEqual(latest['baz'], baz)
  assert pytree_equal(latest['params'], self.params)
  self.assertEqual(latest['global_step'], 5)
  self.assertEqual(latest['completed_epochs'], 4)
def test_update_update_norms(self):
  """Ensure that we update gradient and update norms correctly."""
  init_fn, update_fn, _ = make_training_metrics(
      self.num_train_steps,
      enable_gradient_norm=True,
      enable_update_norm=True,
      enable_update_norms=True)
  initial_metrics_state = init_fn(self.mock_params0)
  updated_metrics_state = update_fn(initial_metrics_state, 0,
                                    self.mock_cost0, self.mock_grad1,
                                    self.mock_params0, self.mock_params1,
                                    self.mock_optimizer_state0)
  updated_metrics_state = update_fn(updated_metrics_state, 1,
                                    self.mock_cost1, self.mock_grad2,
                                    self.mock_params1, self.mock_params2,
                                    self.mock_optimizer_state1)
  self.assertTrue(
      pytree_equal(
          updated_metrics_state['update_norms'], {
              'foo':
                  jnp.array([
                      self.step_size * jnp.linalg.norm(self.mock_grad1['foo']),
                      self.step_size * jnp.linalg.norm(self.mock_grad2['foo']),
                      0.0, 0.0, 0.0
                  ]),
              'bar': {
                  'baz':
                      jnp.array([
                          self.step_size *
                          jnp.linalg.norm(self.mock_grad1['bar']['baz']),
                          self.step_size *
                          jnp.linalg.norm(self.mock_grad2['bar']['baz']),
                          0.0, 0.0, 0.0
                      ])
              }
          }))
  self.assertEqual(updated_metrics_state['gradient_norm'][0],
                   total_tree_norm_l2(self.mock_grad1))
  self.assertEqual(updated_metrics_state['gradient_norm'][1],
                   total_tree_norm_l2(self.mock_grad2))
  self.assertEqual(updated_metrics_state['update_norm'][0],
                   self.step_size * total_tree_norm_l2(self.mock_grad1))
  self.assertEqual(updated_metrics_state['update_norm'][1],
                   self.step_size * total_tree_norm_l2(self.mock_grad2))
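# Sketch of the expectations in test_update_update_norms, assuming (as the
# mock params/grads in this class are set up to reflect) an SGD-style update
# new_params = old_params - step_size * grad:
#
#   update_t        = -step_size * grad_t
#   update_norm[t]  = step_size * ||grad_t||_2   (total, over the whole tree)
#   gradient_norm[t] = ||grad_t||_2              (unscaled)
#
# Steps that have not happened yet keep their initialized value of 0.0, which
# is why the expected 'update_norms' arrays end in zeros.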
def test_train_cost(self):
  """Ensure that the train cost is logged correctly."""
  init_fn, update_fn, _ = make_training_metrics(
      self.num_train_steps, enable_train_cost=True)
  initial_metrics_state = init_fn(self.mock_params0)
  updated_metrics_state = update_fn(initial_metrics_state, 0,
                                    self.mock_cost0, self.mock_grad1,
                                    self.mock_params0, self.mock_params1,
                                    self.mock_optimizer_state0)
  updated_metrics_state = update_fn(updated_metrics_state, 1,
                                    self.mock_cost1, self.mock_grad2,
                                    self.mock_params1, self.mock_params2,
                                    self.mock_optimizer_state1)
  self.assertTrue(
      pytree_equal(
          updated_metrics_state['train_cost'],
          jnp.array([self.mock_cost0, self.mock_cost1, 0.0, 0.0, 0.0])))
def test_update_grad_ema(self):
  """Ensure that the training metrics updater updates grad ema correctly."""
  init_fn, update_fn, _ = make_training_metrics(
      self.num_train_steps, enable_ema=True, ema_beta=0.5)
  initial_metrics_state = init_fn(self.mock_params0)
  updated_metrics_state = update_fn(initial_metrics_state, 0,
                                    self.mock_cost0, self.mock_grad1,
                                    self.mock_params0, self.mock_params1,
                                    self.mock_optimizer_state0)
  updated_metrics_state = update_fn(updated_metrics_state, 1,
                                    self.mock_cost1, self.mock_grad2,
                                    self.mock_params1, self.mock_params2,
                                    self.mock_optimizer_state1)
  self.assertTrue(
      pytree_equal(
          updated_metrics_state['grad_ema'],
          jax.tree_map(lambda x, y, z: 0.25 * x + 0.25 * y + 0.5 * z,
                       self.mock_zeros, self.mock_grad1, self.mock_grad2)))
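# Worked derivation of the expected grad_ema in test_update_grad_ema,
# assuming an EMA update of the form ema <- beta * ema + (1 - beta) * grad
# with beta = 0.5 and an EMA initialized at zero (with beta = 0.5 the two
# common weighting conventions coincide):
#
#   ema_1 = 0.5 * 0     + 0.5 * grad_1
#   ema_2 = 0.5 * ema_1 + 0.5 * grad_2
#         = 0.25 * 0 + 0.25 * grad_1 + 0.5 * grad_2
#
# which matches the asserted `0.25 * x + 0.25 * y + 0.5 * z` leafwise map.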
def test_replicate_and_maybe_restore_from_checkpoint_logic(self):
  """Test that the right checkpoint is returned.

  1. If no external_checkpoint_path was passed, and there is no latest
     checkpoint in the train_dir, then the function should return the
     passed-in params, batch_stats, etc.
  2. If an external checkpoint was provided but no latest checkpoint exists
     in the train_dir, then the function should return the external
     checkpoint.
  3. If a latest checkpoint exists in the train_dir, then the function
     should return that checkpoint.

  In the interest of conciseness, this test only checks the params, not the
  batch_stats, optimizer_state, or training_metrics. The test
  test_all_variables_restored() below covers the other three.
  """
  # Mock parameters.
  initial_params = {'foo': 1.0}
  latest_params = {'foo': 2.0}
  external_params = {'foo': 3.0}

  fresh_train_dir = tempfile.mkdtemp()
  external_dir = tempfile.mkdtemp()

  # Two helper functions.
  def save_checkpoint(train_dir, global_step, preemption_count,
                      sum_train_cost, params):
    """Helper function to save a checkpoint."""
    checkpoint.save_checkpoint(
        train_dir=train_dir,
        step=global_step,
        state=dict(
            global_step=global_step,
            preemption_count=preemption_count,
            sum_train_cost=sum_train_cost,
            optimizer_state={},
            params=params,
            batch_stats={},
            training_metrics_grabber={}),
        max_to_keep=1)

  def maybe_restore_checkpoint(params, train_dir, external_checkpoint_path):
    """Helper function to replicate_and_maybe_restore a checkpoint."""
    (_, ret_params, _, _, ret_global_step, ret_sum_train_cost,
     ret_preemption_count, ret_is_restored
    ) = checkpoint.replicate_and_maybe_restore_checkpoint(
        {}, params, {}, {}, train_dir, external_checkpoint_path)
    ret_params_unrep = jax.device_get(jax_utils.unreplicate(ret_params))
    return (ret_params_unrep, ret_global_step, ret_sum_train_cost,
            ret_preemption_count, ret_is_restored)

  # Save external checkpoint.
  save_checkpoint(
      train_dir=external_dir,
      global_step=5,
      preemption_count=4,
      sum_train_cost=7.0,
      params=external_params)
  external_checkpoint_path = os.path.join(external_dir, 'ckpt_' + str(5))

  # If no latest checkpoint exists, and no external checkpoint was provided,
  # the function should return the passed-in params.
  (ret_params, ret_global_step, ret_sum_train_cost, ret_preemption_count,
   ret_is_restored) = maybe_restore_checkpoint(initial_params,
                                               fresh_train_dir, None)
  self.assertEqual(ret_preemption_count, 0)
  self.assertEqual(ret_global_step, 0)
  self.assertEqual(ret_sum_train_cost, 0.0)
  self.assertFalse(ret_is_restored)
  assert pytree_equal(ret_params, initial_params)

  # If no latest checkpoint exists, and an external checkpoint was provided,
  # the function should return the external checkpoint.
  (ret_params, ret_global_step, ret_sum_train_cost, ret_preemption_count,
   ret_is_restored) = maybe_restore_checkpoint(initial_params,
                                               fresh_train_dir,
                                               external_checkpoint_path)
  self.assertEqual(ret_preemption_count, 4)
  self.assertEqual(ret_global_step, 5)
  self.assertEqual(ret_sum_train_cost, 7.0)
  self.assertFalse(ret_is_restored)
  assert pytree_equal(ret_params, external_params)

  # Save latest checkpoint.
  save_checkpoint(
      train_dir=fresh_train_dir,
      global_step=10,
      preemption_count=2,
      sum_train_cost=2.2,
      params=latest_params)

  # If a latest checkpoint exists, then even if an external checkpoint was
  # provided, the function should return the latest checkpoint.
  (ret_params, ret_global_step, ret_sum_train_cost, ret_preemption_count,
   ret_is_restored) = maybe_restore_checkpoint(initial_params,
                                               fresh_train_dir,
                                               external_checkpoint_path)
  self.assertEqual(ret_preemption_count, 2)
  self.assertEqual(ret_global_step, 10)
  self.assertEqual(ret_sum_train_cost, 2.2)
  self.assertTrue(ret_is_restored)
  assert pytree_equal(ret_params, latest_params)

  shutil.rmtree(fresh_train_dir)
  shutil.rmtree(external_dir)
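# Restore priority exercised by the test above, stated as an assumption about
# replicate_and_maybe_restore_checkpoint's behavior (inferred from the
# assertions, not copied from its implementation):
#
#   if a latest checkpoint exists in train_dir:  restore it  (is_restored=True)
#   elif external_checkpoint_path is provided:   restore it  (is_restored=False)
#   else:                                        keep the passed-in state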
def test_all_variables_restored(self):
  """Test that all variables are properly restored.

  This test checks that optimizer_state, params, batch_stats, and
  training_metrics_grabber are all properly restored after training is
  pre-empted.
  """
  fresh_train_dir = tempfile.mkdtemp()

  global_step = 100
  preemption_count = 8
  sum_train_cost = 0.9

  saved_optimizer_state = {'second_moments': 7}
  saved_params = {'kernel': 3}
  saved_batch_stats = {'mean': 2}
  saved_training_metrics = {'ema': 4}

  initial_optimizer_state = {'second_moments': 0}
  initial_params = {'kernel': 0}
  initial_batch_stats = {'mean': 0}
  initial_training_metrics = {'ema': 0}

  checkpoint.save_checkpoint(
      train_dir=fresh_train_dir,
      step=global_step,
      state=dict(
          global_step=global_step,
          preemption_count=preemption_count,
          sum_train_cost=sum_train_cost,
          optimizer_state=saved_optimizer_state,
          params=saved_params,
          batch_stats=saved_batch_stats,
          training_metrics_grabber=saved_training_metrics),
      max_to_keep=1)

  (
      ret_state,
      ret_params,
      ret_batch_stats,
      ret_training_metrics,
      ret_global_step,
      ret_sum_train_cost,
      ret_preemption_count,
      ret_is_restored,
  ) = checkpoint.replicate_and_maybe_restore_checkpoint(
      initial_optimizer_state, initial_params, initial_batch_stats,
      initial_training_metrics, fresh_train_dir)

  assert pytree_equal(
      jax.device_get(jax_utils.unreplicate(ret_state)),
      saved_optimizer_state)
  assert pytree_equal(
      jax.device_get(jax_utils.unreplicate(ret_params)), saved_params)
  assert pytree_equal(
      jax.device_get(jax_utils.unreplicate(ret_batch_stats)),
      saved_batch_stats)
  assert pytree_equal(
      jax.device_get(jax_utils.unreplicate(ret_training_metrics)),
      saved_training_metrics)
  self.assertEqual(ret_sum_train_cost, sum_train_cost)
  self.assertEqual(ret_preemption_count, preemption_count)
  self.assertEqual(ret_global_step, global_step)
  self.assertEqual(ret_is_restored, True)

  shutil.rmtree(fresh_train_dir)