def test_load_latest_checkpoint(self):
    """The latest checkpoint is the one with the highest round number.

    Round 10 is saved before round 3, and the test expects round 10 (and its
    state) back, i.e. save order does not determine "latest".
    """
    earlier_saved = test_util.create_mock_state(seed=1)
    later_saved = test_util.create_mock_state(seed=2)
    # keep=2 with only two saves, so neither checkpoint should be pruned.
    for state, round_num in ((earlier_saved, 10), (later_saved, 3)):
        checkpoint.save_checkpoint(
            self._root_dir, state, round_num=round_num, keep=2)
    loaded_state, loaded_round = checkpoint.load_latest_checkpoint(
        root_dir=self._root_dir)
    self.assertEqual(loaded_round, 10)
    tf.nest.map_structure(self.assertAllEqual, loaded_state, earlier_saved)
def test_l2_regularizer_evaluation_with_center(self):
    """Centering at 0.2 * params scales the penalty by (1 - 0.2)**2.

    With center_params = 0.2 * params, the test expects the centered penalty
    to equal the uncentered penalty times (1 - 0.2)**2, i.e. the penalty is
    presumably quadratic in the distance from the center.
    """
    scale = 0.2
    params = test_util.create_mock_state(seed=0).params
    original_output = regularizers.L2Regularizer()(params)
    # Use jax.tree_util.tree_map; the top-level jax.tree_map alias is
    # deprecated.
    center_params = jax.tree_util.tree_map(lambda leaf: leaf * scale, params)
    output = regularizers.L2Regularizer(center_params=center_params)(params)
    # Both sides are computed along different arithmetic paths (presumably in
    # float32). The default 7-decimal-place check is finer than that
    # precision at this magnitude (~24), so use an absolute tolerance, as
    # test_l2_regularizer_parameter_weight does.
    self.assertAlmostEqual(
        output, original_output * (1. - scale)**2, delta=1e-5)
def test_save_state(self):
    """save_state produces a file/object at the requested path."""
    target = os.path.join(self.get_temp_dir(), 'state')
    serialization.save_state(test_util.create_mock_state(), target)
    self.assertTrue(tf.io.gfile.exists(target))
def test_l2_regularizer_parameter_weight(self):
    """Uniform per-parameter weights of 2 double the unweighted penalty."""
    params = test_util.create_mock_state(seed=0).params
    original_output = regularizers.L2Regularizer()(params)
    # Use jax.tree_util.tree_map; the top-level jax.tree_map alias is
    # deprecated. Every leaf gets a same-shaped array filled with 2.
    param_weights = jax.tree_util.tree_map(
        lambda leaf: 2 * jax.numpy.ones(leaf.shape), params)
    output = regularizers.L2Regularizer(
        weight=1.0, param_weights=param_weights)(params)
    # Absolute tolerance: the weighted and scaled values come from different
    # arithmetic paths, so exact equality is not expected.
    self.assertAlmostEqual(output, 2 * original_output, delta=1e-5)
def test_save_checkpoint_keep(self):
    """Only the `keep` newest checkpoints survive repeated saves."""
    state = test_util.create_mock_state()
    round_nums = range(3)
    for round_num in round_nums:
        checkpoint.save_checkpoint(
            self._root_dir, state, round_num=round_num, keep=2)
    # Rounds 0, 1, 2 were saved with keep=2, so round 0 should be gone.
    expected = ['checkpoint_00000001', 'checkpoint_00000002']
    self.assertCountEqual(tf.io.gfile.listdir(self._root_dir), expected)
def test_load_state(self):
    """A state written by save_state round-trips through load_state."""
    temp_dir = self.get_temp_dir()
    path = os.path.join(temp_dir, 'state')
    init_state = test_util.create_mock_state()
    serialization.save_state(init_state, path)
    state = serialization.load_state(path)
    # Use jax.tree_util directly; the top-level jax.tree_flatten alias is
    # deprecated.
    expected_flat, expected_tree_def = jax.tree_util.tree_flatten(init_state)
    actual_flat, actual_tree_def = jax.tree_util.tree_flatten(state)
    # Compare structure first. Equal treedefs imply equal leaf counts, so the
    # zip() below cannot silently drop unmatched leaves.
    self.assertEqual(expected_tree_def, actual_tree_def)
    for expected_array, actual_array in zip(expected_flat, actual_flat):
        self.assertAllEqual(expected_array, actual_array)
def test_l2_regularizer_weight(self):
    """A weight of 0.2 scales the default-constructed penalty by 0.2."""
    weight = 0.2
    params = test_util.create_mock_state(seed=0).params
    baseline = regularizers.L2Regularizer()(params)
    weighted = regularizers.L2Regularizer(weight=weight)(params)
    self.assertAlmostEqual(weighted, baseline * weight)
def test_l2_regularizer(self):
    """Default L2Regularizer on the seed-0 mock params matches a golden value."""
    mock_params = test_util.create_mock_state(seed=0).params
    regularizer = regularizers.L2Regularizer()
    # Golden value tied to create_mock_state(seed=0); update it if the mock
    # state's initialization changes.
    self.assertAlmostEqual(regularizer(mock_params), 37.64189)
def test_save_checkpoint(self):
    """One save without round_num leaves exactly checkpoint_00000000."""
    mock_state = test_util.create_mock_state()
    checkpoint.save_checkpoint(self._root_dir, mock_state)
    saved = tf.io.gfile.listdir(self._root_dir)
    self.assertEqual(saved, ['checkpoint_00000000'])
def test_create_mock_state(self):
    """The seed-0 mock state has one 'linear' module holding 124 values."""
    params = test_util.create_mock_state(seed=0).params
    self.assertEqual(list(params.keys()), ['linear'])
    self.assertEqual(hk.data_structures.tree_size(params), 124)