Code example #1
    def test_optstate_sumsq(self):
        """Test that optstate sumsq and sumsq are computed correctly."""
        init_fn, update_fn, _ = make_training_metrics(
            self.num_train_steps,
            optstate_sumsq_fields=['nu'],
            optstate_sum_fields=['nu'])
        initial_metrics_state = init_fn(self.mock_params0)
        self.assertTrue(
            pytree_equal(initial_metrics_state['optstate_sumsq'],
                         {'nu': jnp.zeros(self.num_train_steps)}))
        self.assertTrue(
            pytree_equal(initial_metrics_state['optstate_sum'],
                         {'nu': jnp.zeros(self.num_train_steps)}))
        updated_metrics_state = update_fn(initial_metrics_state, 0,
                                          self.mock_cost0, self.mock_grad1,
                                          self.mock_params0, self.mock_params1,
                                          self.mock_optimizer_state0)
        updated_metrics_state = update_fn(updated_metrics_state, 1,
                                          self.mock_cost1, self.mock_grad2,
                                          self.mock_params1, self.mock_params2,
                                          self.mock_optimizer_state1)

        self.assertEqual(updated_metrics_state['optstate_sumsq']['nu'][0],
                         total_tree_norm_sql2(self.mock_nu0))
        self.assertEqual(updated_metrics_state['optstate_sumsq']['nu'][1],
                         total_tree_norm_sql2(self.mock_nu1))

        self.assertEqual(updated_metrics_state['optstate_sum']['nu'][0],
                         total_tree_sum(self.mock_nu0))
        self.assertEqual(updated_metrics_state['optstate_sum']['nu'][1],
                         total_tree_sum(self.mock_nu1))
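The helpers total_tree_norm_sql2 and total_tree_sum are not shown in this listing. Judging only from how the assertions use them, they reduce over every leaf of a pytree; a minimal sketch (my reconstruction, not the project's actual implementation):

import jax
import jax.numpy as jnp

def total_tree_sum(tree):
    # Sum of all elements across every leaf of the pytree (sketch).
    return jax.tree_util.tree_reduce(
        lambda acc, x: acc + jnp.sum(x), tree, 0.0)

def total_tree_norm_sql2(tree):
    # Squared global L2 norm: sum of squares across every leaf (sketch).
    return jax.tree_util.tree_reduce(
        lambda acc, x: acc + jnp.sum(jnp.square(x)), tree, 0.0)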
Code example #2
    def test_update_param_norm(self):
        """Ensure that the training metrics updater updates param norm correctly."""

        init_fn, update_fn, _ = make_training_metrics(self.num_train_steps)
        initial_metrics_state = init_fn(self.mock_params0)
        updated_metrics_state = update_fn(initial_metrics_state, 0,
                                          self.mock_cost0, self.mock_grad1,
                                          self.mock_params0, self.mock_params1,
                                          self.mock_optimizer_state0)
        self.assertTrue(
            pytree_equal(
                updated_metrics_state['param_norm'], {
                    'foo': jnp.linalg.norm(self.mock_params0['foo']),
                    'bar': {
                        'baz': jnp.linalg.norm(self.mock_params0['bar']['baz'])
                    }
                }))

        updated_metrics_state = update_fn(updated_metrics_state, 1,
                                          self.mock_cost1, self.mock_grad2,
                                          self.mock_params1, self.mock_params2,
                                          self.mock_optimizer_state1)

        self.assertTrue(
            pytree_equal(
                updated_metrics_state['param_norm'], {
                    'foo': jnp.linalg.norm(self.mock_params1['foo']),
                    'bar': {
                        'baz': jnp.linalg.norm(self.mock_params1['bar']['baz'])
                    }
                }))
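pytree_equal is another unshown test utility; from its use on nested dicts of scalars and arrays, it presumably requires identical tree structure and element-wise equal leaves. A hedged sketch:

import jax
import jax.numpy as jnp

def pytree_equal(tree1, tree2):
    # True iff both pytrees share a structure and all leaves are equal (sketch).
    if jax.tree_util.tree_structure(tree1) != jax.tree_util.tree_structure(tree2):
        return False
    equal = jax.tree_util.tree_map(
        lambda a, b: bool(jnp.array_equal(a, b)), tree1, tree2)
    return all(jax.tree_util.tree_leaves(equal))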
Code example #3
    def test_summarize(self):
        """Test the training metrics summarizer."""
        _, _, summarize_fn = make_training_metrics(self.num_train_steps,
                                                   enable_train_cost=True,
                                                   enable_ema=True)
        metrics_state = {
            'train_cost': jnp.array([1.0, 0.5, 0.25, 0.0, 0.0]),
            'param_norm': {
                'foo': 7.0,
                'bar': {
                    'baz': 2.0
                }
            },
            'grad_ema': {
                'foo': 1 * jnp.ones(5),
                'bar': {
                    'baz': 2 * jnp.ones(10)
                }
            },
            'grad_sq_ema': {
                'foo': 2 * jnp.ones(5),
                'bar': {
                    'baz': 6 * jnp.ones(10)
                }
            },
            'update_ema': {
                'foo': 2 * jnp.ones(5),
                'bar': {
                    'baz': 1 * jnp.ones(10)
                }
            },
            'update_sq_ema': {
                'foo': 6 * jnp.ones(5),
                'bar': {
                    'baz': 2 * jnp.ones(10)
                }
            },
        }
        tree_summary = summarize_fn(metrics_state)
        self.assertTrue(
            pytree_equal(
                tree_summary, {
                    'param_norm': {
                        '/foo': 7.0,
                        '/bar/baz': 2.0
                    },
                    'grad_var': {
                        '/foo': 5 * (2 - 1**2),
                        '/bar/baz': 10 * (6 - 2**2)
                    },
                    'update_var': {
                        '/foo': 5 * (6 - 2**2),
                        '/bar/baz': 10 * (2 - 1**2)
                    },
                    'update_ratio': {
                        '/foo': 5 * (6 - 2**2) / 7.0,
                        '/bar/baz': 10 * (2 - 1**2) / 2.0
                    }
                }))
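The expected values encode the identity Var[g] = E[g^2] - (E[g])^2, summed over the elements of each leaf: for '/foo' (5 elements, grad_ema = 1, grad_sq_ema = 2) that is 5 * (2 - 1**2) = 5. A hypothetical per-leaf summarizer consistent with these numbers:

import jax
import jax.numpy as jnp

def variance_from_emas(mean_ema, sq_ema):
    # Per-leaf sum of E[x^2] - E[x]^2, matching the expected 'grad_var'
    # and 'update_var' entries above (sketch).
    return jax.tree_util.tree_map(
        lambda m, s: jnp.sum(s - jnp.square(m)), mean_ema, sq_ema)

The 'update_ratio' entries then divide each update variance by the corresponding param_norm leaf.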
Code example #4
    def test_init(self):
        """Test the training metrics initializer."""

        zeros_like_params = jax.tree_map(jnp.zeros_like, self.mock_params0)
        zeros_scalar_like_params = jax.tree_map(lambda x: 0.0,
                                                self.mock_params0)
        zeros_timeseries = jnp.zeros(self.num_train_steps)
        zeros_timeseries_like_params = jax.tree_map(
            lambda x: jnp.zeros(self.num_train_steps), self.mock_params0)

        # Test init with everything disabled.
        init_fn, _, _ = make_training_metrics(self.num_train_steps)
        initial_metrics_state = init_fn(self.mock_params0)
        self.assertTrue(
            pytree_equal({'param_norm': zeros_scalar_like_params},
                         initial_metrics_state))

        # Test init with everything enabled.
        init_fn, _, _ = make_training_metrics(self.num_train_steps,
                                              enable_ema=True,
                                              enable_train_cost=True,
                                              enable_param_norms=True,
                                              enable_gradient_norm=True,
                                              enable_update_norm=True,
                                              enable_update_norms=True)
        initial_metrics_state = init_fn(self.mock_params0)
        self.assertTrue(
            pytree_equal(
                initial_metrics_state, {
                    'train_cost': zeros_timeseries,
                    'param_norm': zeros_scalar_like_params,
                    'grad_ema': zeros_like_params,
                    'grad_sq_ema': zeros_like_params,
                    'update_ema': zeros_like_params,
                    'update_sq_ema': zeros_like_params,
                    'param_norms': zeros_timeseries_like_params,
                    'gradient_norm': zeros_timeseries,
                    'update_norm': zeros_timeseries,
                    'update_norms': zeros_timeseries_like_params
                }))
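These tests all rely on fixtures (self.mock_params0, self.num_train_steps, and so on) from a setUp method that the listing omits. The assertions pin down the essentials: the time series have length 5, and the params pytree has a 5-element 'foo' leaf and a 10-element 'bar'/'baz' leaf. A hypothetical minimal fixture consistent with that:

import unittest

import jax.numpy as jnp

class TrainingMetricsTest(unittest.TestCase):  # hypothetical class name

    def setUp(self):
        self.num_train_steps = 5  # matches the length-5 time series above
        # Only the structure and shapes are implied by the assertions;
        # the values here are placeholders.
        self.mock_params0 = {
            'foo': jnp.ones(5),
            'bar': {'baz': jnp.ones(10)},
        }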
Code example #5
File: test_checkpoint.py  Project: google/init2winit
    def test_save_load_roundtrip(self):
        """Test that saving and loading produces the original state."""
        baz = ['a', 'b', 'ccc']
        state = dict(params=self.params,
                     global_step=5,
                     completed_epochs=4,
                     baz=baz)
        checkpoint.save_checkpoint(self.test_dir, 0, state)
        latest = checkpoint.load_latest_checkpoint(self.test_dir, target=state)

        self.assertEqual(latest['baz'], baz)
        assert pytree_equal(latest['params'], self.params)
        self.assertEqual(latest['global_step'], 5)
        self.assertEqual(latest['completed_epochs'], 4)
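Here self.test_dir and self.params likewise come from setUp; a plausible fixture (hypothetical) is a scratch temp directory plus a small params pytree, cleaned up in tearDown:

import shutil
import tempfile
import unittest

import jax.numpy as jnp

class CheckpointTest(unittest.TestCase):  # hypothetical class name

    def setUp(self):
        self.test_dir = tempfile.mkdtemp()  # scratch dir for checkpoint files
        self.params = {'w': jnp.ones(3)}    # placeholder params pytree

    def tearDown(self):
        shutil.rmtree(self.test_dir)        # remove the scratch dir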
Code example #6
    def test_update_update_norms(self):
        """Ensure that we update gradient and update norms correctly."""
        init_fn, update_fn, _ = make_training_metrics(
            self.num_train_steps,
            enable_gradient_norm=True,
            enable_update_norm=True,
            enable_update_norms=True)
        initial_metrics_state = init_fn(self.mock_params0)
        updated_metrics_state = update_fn(initial_metrics_state, 0,
                                          self.mock_cost0, self.mock_grad1,
                                          self.mock_params0, self.mock_params1,
                                          self.mock_optimizer_state0)
        updated_metrics_state = update_fn(updated_metrics_state, 1,
                                          self.mock_cost1, self.mock_grad2,
                                          self.mock_params1, self.mock_params2,
                                          self.mock_optimizer_state1)
        self.assertTrue(
            pytree_equal(
                updated_metrics_state['update_norms'], {
                    'foo':
                    jnp.array([
                        self.step_size *
                        jnp.linalg.norm(self.mock_grad1['foo']),
                        self.step_size *
                        jnp.linalg.norm(self.mock_grad2['foo']), 0.0, 0.0, 0.0
                    ]),
                    'bar': {
                        'baz':
                        jnp.array([
                            self.step_size *
                            jnp.linalg.norm(self.mock_grad1['bar']['baz']),
                            self.step_size *
                            jnp.linalg.norm(self.mock_grad2['bar']['baz']),
                            0.0, 0.0, 0.0
                        ])
                    }
                }))

        self.assertEqual(updated_metrics_state['gradient_norm'][0],
                         total_tree_norm_l2(self.mock_grad1))
        self.assertEqual(updated_metrics_state['gradient_norm'][1],
                         total_tree_norm_l2(self.mock_grad2))

        self.assertEqual(updated_metrics_state['update_norm'][0],
                         self.step_size * total_tree_norm_l2(self.mock_grad1))
        self.assertEqual(updated_metrics_state['update_norm'][1],
                         self.step_size * total_tree_norm_l2(self.mock_grad2))
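Two implicit facts carry this test: total_tree_norm_l2 is the (unsquared) global L2 norm over a pytree, and the mock optimizer applies plain SGD, so each update equals step_size times the gradient. A sketch of the norm helper, analogous to total_tree_norm_sql2 above:

import jax
import jax.numpy as jnp

def total_tree_norm_l2(tree):
    # Global L2 norm: sqrt of the sum of squares across every leaf (sketch).
    return jnp.sqrt(jax.tree_util.tree_reduce(
        lambda acc, x: acc + jnp.sum(jnp.square(x)), tree, 0.0))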
Code example #7
    def test_train_cost(self):
        """Ensure that the train cost is logged correctly."""
        init_fn, update_fn, _ = make_training_metrics(self.num_train_steps,
                                                      enable_train_cost=True)
        initial_metrics_state = init_fn(self.mock_params0)
        updated_metrics_state = update_fn(initial_metrics_state, 0,
                                          self.mock_cost0, self.mock_grad1,
                                          self.mock_params0, self.mock_params1,
                                          self.mock_optimizer_state0)
        updated_metrics_state = update_fn(updated_metrics_state, 1,
                                          self.mock_cost1, self.mock_grad2,
                                          self.mock_params1, self.mock_params2,
                                          self.mock_optimizer_state1)

        self.assertTrue(
            pytree_equal(
                updated_metrics_state['train_cost'],
                jnp.array([self.mock_cost0, self.mock_cost1, 0.0, 0.0, 0.0])))
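The expected array shows the bookkeeping convention: train_cost is a preallocated buffer of length num_train_steps, and each update writes the step's cost into its slot. In JAX this is done functionally; a sketch of the rule (not the library's code):

import jax.numpy as jnp

def record_train_cost(train_cost, step, cost):
    # Write this step's cost into slot `step` of the preallocated buffer.
    return train_cost.at[step].set(cost)

# Two steps starting from jnp.zeros(5) yield [cost0, cost1, 0., 0., 0.],
# matching the assertion above.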
Code example #8
    def test_update_grad_ema(self):
        """Ensure that the training metrics updater updates grad ema correctly."""

        init_fn, update_fn, _ = make_training_metrics(self.num_train_steps,
                                                      enable_ema=True,
                                                      ema_beta=0.5)
        initial_metrics_state = init_fn(self.mock_params0)
        updated_metrics_state = update_fn(initial_metrics_state, 0,
                                          self.mock_cost0, self.mock_grad1,
                                          self.mock_params0, self.mock_params1,
                                          self.mock_optimizer_state0)
        updated_metrics_state = update_fn(updated_metrics_state, 1,
                                          self.mock_cost1, self.mock_grad2,
                                          self.mock_params1, self.mock_params2,
                                          self.mock_optimizer_state1)

        self.assertTrue(
            pytree_equal(
                updated_metrics_state['grad_ema'],
                jax.tree_map(lambda x, y, z: 0.25 * x + 0.25 * y + 0.5 * z,
                             self.mock_zeros, self.mock_grad1,
                             self.mock_grad2)))
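The coefficients in the expected value are two applications of a standard exponential moving average with ema_beta = 0.5: after two steps, ema = beta**2 * 0 + beta * (1 - beta) * grad1 + (1 - beta) * grad2 = 0.25 * grad1 + 0.5 * grad2. A sketch of the presumed per-leaf rule:

import jax

def ema_update(ema_tree, new_tree, beta):
    # Standard EMA: new = beta * old + (1 - beta) * observation (sketch).
    return jax.tree_util.tree_map(
        lambda e, x: beta * e + (1.0 - beta) * x, ema_tree, new_tree)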
Code example #9
File: test_checkpoint.py  Project: google/init2winit
    def test_replicate_and_maybe_restore_from_checkpoint_logic(self):
        """Test that the right checkpoint is returned.

      1.  If no external_checkpoint_path was passed, and if there is no
      latest checkpoint in the train_dir, then the function should return
      the passed-in params, batch_stats, etc.
      2.  If an external checkpoint was provided but no latest checkpoint
      exists in the train_dir, then the function should return the external
      checkpoint.
      3.  If a latest checkpoint exists in the train dir, then the function
      should return that checkpoint.

      In the interest of conciseness, this test only checks the params,
      not the batch_stats, optimizer_state, or training_metics.  The below test
      test_all_variables_restored() covers the other three.
    """
        # Mock parameters.
        initial_params = {'foo': 1.0}
        latest_params = {'foo': 2.0}
        external_params = {'foo': 3.0}

        fresh_train_dir = tempfile.mkdtemp()
        external_dir = tempfile.mkdtemp()

        # Two helper functions.
        def save_checkpoint(train_dir, global_step, preemption_count,
                            sum_train_cost, params):
            """Helper function to save a checkpoint."""

            checkpoint.save_checkpoint(train_dir=train_dir,
                                       step=global_step,
                                       state=dict(
                                           global_step=global_step,
                                           preemption_count=preemption_count,
                                           sum_train_cost=sum_train_cost,
                                           optimizer_state={},
                                           params=params,
                                           batch_stats={},
                                           training_metrics_grabber={}),
                                       max_to_keep=1)

        def maybe_restore_checkpoint(params, train_dir,
                                     external_checkpoint_path):
            """Helper function to replicate_and_maybe_restore a checkpoint."""

            (_, ret_params, _, _, ret_global_step, ret_sum_train_cost,
             ret_preemption_count, ret_is_restored
             ) = checkpoint.replicate_and_maybe_restore_checkpoint(
                 {}, params, {}, {}, train_dir, external_checkpoint_path)

            ret_params_unrep = jax.device_get(
                jax_utils.unreplicate(ret_params))

            return (ret_params_unrep, ret_global_step, ret_sum_train_cost,
                    ret_preemption_count, ret_is_restored)

        # Save external checkpoint.
        save_checkpoint(train_dir=external_dir,
                        global_step=5,
                        preemption_count=4,
                        sum_train_cost=7.0,
                        params=external_params)
        external_checkpoint_path = os.path.join(external_dir, 'ckpt_' + str(5))

        # If no latest checkpoint exists, and no external checkpoint was provided,
        # the function should return the passed-in params.

        (ret_params, ret_global_step, ret_sum_train_cost, ret_preemption_count,
         ret_is_restored) = maybe_restore_checkpoint(initial_params,
                                                     fresh_train_dir, None)

        self.assertEqual(ret_preemption_count, 0)
        self.assertEqual(ret_global_step, 0)
        self.assertEqual(ret_sum_train_cost, 0.0)
        self.assertFalse(ret_is_restored)
        assert pytree_equal(ret_params, initial_params)

        # If no latest checkpoint exists, and an external checkpoint was provided,
        # the function should return the external checkpoint.

        (ret_params, ret_global_step, ret_sum_train_cost, ret_preemption_count,
         ret_is_restored) = maybe_restore_checkpoint(initial_params,
                                                     fresh_train_dir,
                                                     external_checkpoint_path)

        self.assertEqual(ret_preemption_count, 4)
        self.assertEqual(ret_global_step, 5)
        self.assertEqual(ret_sum_train_cost, 7.0)
        self.assertFalse(ret_is_restored)
        assert pytree_equal(ret_params, external_params)

        # Save latest checkpoint.
        save_checkpoint(train_dir=fresh_train_dir,
                        global_step=10,
                        preemption_count=2,
                        sum_train_cost=2.2,
                        params=latest_params)

        # If a latest checkpoint exists, then even if an external checkpoint was
        # provided, the function should return the latest checkpoint.

        (ret_params, ret_global_step, ret_sum_train_cost, ret_preemption_count,
         ret_is_restored) = maybe_restore_checkpoint(initial_params,
                                                     fresh_train_dir,
                                                     external_checkpoint_path)

        self.assertEqual(ret_preemption_count, 2)
        self.assertEqual(ret_global_step, 10)
        self.assertEqual(ret_sum_train_cost, 2.2)
        self.assertTrue(ret_is_restored)
        assert pytree_equal(ret_params, latest_params)

        shutil.rmtree(fresh_train_dir)
        shutil.rmtree(external_dir)
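The three scenarios pin down a simple precedence rule: a latest checkpoint in train_dir beats an external checkpoint, which beats the passed-in state, and is_restored is True only when the latest checkpoint is used (i.e., when resuming the same run). As pseudocode (a sketch of the logic under test, not the implementation):

def choose_restore_source(latest_ckpt, external_ckpt, initial_state):
    # Precedence: latest checkpoint in train_dir > external > fresh state.
    if latest_ckpt is not None:
        return latest_ckpt, True      # is_restored: resuming this run
    if external_ckpt is not None:
        return external_ckpt, False   # warm start, not a resume
    return initial_state, False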
Code example #10
File: test_checkpoint.py  Project: google/init2winit
    def test_all_variables_restored(self):
        """Test that all variables are properly restored.

    This test checks that optimizer_state, params, batch_stats, and
    training_metrics_grabber are all properly restored after training
    is pre-empted.
    """

        fresh_train_dir = tempfile.mkdtemp()
        global_step = 100
        preemption_count = 8
        sum_train_cost = 0.9

        saved_optimizer_state = {'second_moments': 7}
        saved_params = {'kernel': 3}
        saved_batch_stats = {'mean': 2}
        saved_training_metrics = {'ema': 4}

        initial_optimizer_state = {'second_moments': 0}
        initial_params = {'kernel': 0}
        initial_batch_stats = {'mean': 0}
        initial_training_metrics = {'ema': 0}

        checkpoint.save_checkpoint(
            train_dir=fresh_train_dir,
            step=global_step,
            state=dict(global_step=global_step,
                       preemption_count=preemption_count,
                       sum_train_cost=sum_train_cost,
                       optimizer_state=saved_optimizer_state,
                       params=saved_params,
                       batch_stats=saved_batch_stats,
                       training_metrics_grabber=saved_training_metrics),
            max_to_keep=1)

        (
            ret_state,
            ret_params,
            ret_batch_stats,
            ret_training_metrics,
            ret_global_step,
            ret_sum_train_cost,
            ret_preemption_count,
            ret_is_restored,
        ) = checkpoint.replicate_and_maybe_restore_checkpoint(
            initial_optimizer_state, initial_params, initial_batch_stats,
            initial_training_metrics, fresh_train_dir)

        assert pytree_equal(jax.device_get(jax_utils.unreplicate(ret_state)),
                            saved_optimizer_state)
        assert pytree_equal(jax.device_get(jax_utils.unreplicate(ret_params)),
                            saved_params)
        assert pytree_equal(
            jax.device_get(jax_utils.unreplicate(ret_batch_stats)),
            saved_batch_stats)
        assert pytree_equal(
            jax.device_get(jax_utils.unreplicate(ret_training_metrics)),
            saved_training_metrics)
        self.assertEqual(ret_sum_train_cost, sum_train_cost)
        self.assertEqual(ret_preemption_count, preemption_count)
        self.assertEqual(ret_global_step, global_step)
        self.assertTrue(ret_is_restored)

        shutil.rmtree(fresh_train_dir)
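One detail worth noting: replicate_and_maybe_restore_checkpoint returns pytrees replicated across local devices (for pmap-style training), which is why each comparison above first strips the leading device axis with flax's jax_utils.unreplicate and copies the result to host memory with jax.device_get:

import jax
from flax import jax_utils

# Typical pattern for comparing a replicated pytree against host values.
host_params = jax.device_get(jax_utils.unreplicate(ret_params))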