예제 #1
0
  def test_load_latest_checkpoint(self):
    """load_latest_checkpoint picks the highest round, not the last write.

    Round 10 is saved before round 3, so write order and round order
    disagree; the test expects round 10 and its state back.
    """
    round_10_state = test_util.create_mock_state(seed=1)
    round_3_state = test_util.create_mock_state(seed=2)
    for saved_state, saved_round in ((round_10_state, 10),
                                     (round_3_state, 3)):
      checkpoint.save_checkpoint(
          self._root_dir, saved_state, round_num=saved_round, keep=2)

    restored_state, restored_round = checkpoint.load_latest_checkpoint(
        root_dir=self._root_dir)

    self.assertEqual(restored_round, 10)
    tf.nest.map_structure(self.assertAllEqual, restored_state, round_10_state)
예제 #2
0
 def test_l2_regularizer_evaluation_with_center(self):
     """Centering at 0.2 * params scales the L2 penalty by (1 - 0.2)**2.

     The regularizer is evaluated at `params` with `center_params` equal to
     `0.2 * params`. The test expects the result to equal the uncentered
     penalty times (1 - 0.2)**2.
     """
     params = test_util.create_mock_state(seed=0).params
     original_output = regularizers.L2Regularizer()(params)
     # `jax.tree_map` is a deprecated alias (removed in newer JAX releases);
     # `jax.tree_util.tree_map` is the stable spelling.
     center_params = jax.tree_util.tree_map(lambda l: l * 0.2, params)
     output = regularizers.L2Regularizer(
         center_params=center_params)(params)
     # Outputs are presumably float32 (JAX default). The default 7-place
     # tolerance of assertAlmostEqual is below float32 resolution at this
     # magnitude, so it would only pass on a bit-exact result; use an explicit
     # delta instead, as the parameter-weight test does.
     self.assertAlmostEqual(
         output, original_output * (1. - 0.2)**2, delta=1e-5)
예제 #3
0
  def test_save_state(self):
    """save_state creates an entry at the requested path."""
    state_path = os.path.join(self.get_temp_dir(), 'state')

    serialization.save_state(test_util.create_mock_state(), state_path)

    self.assertTrue(tf.io.gfile.exists(state_path))
예제 #4
0
    def test_l2_regularizer_parameter_weight(self):
        """A uniform per-parameter weight of 2 doubles the L2 penalty.

        Every leaf of `param_weights` is an array of 2s matching the shape
        of the corresponding parameter leaf. With `weight=1.0`, the test
        expects twice the unweighted penalty.
        """
        params = test_util.create_mock_state(seed=0).params
        original_output = regularizers.L2Regularizer()(params)

        # `jax.tree_map` is a deprecated alias (removed in newer JAX
        # releases); `jax.tree_util.tree_map` is the stable spelling.
        param_weights = jax.tree_util.tree_map(
            lambda leaf: 2 * jax.numpy.ones(leaf.shape), params)
        output = regularizers.L2Regularizer(
            weight=1.0, param_weights=param_weights)(params)
        self.assertAlmostEqual(output, 2 * original_output, delta=1e-5)
예제 #5
0
  def test_save_checkpoint_keep(self):
    """With keep=2, only the last two of rounds 0..2 remain on disk."""
    mock_state = test_util.create_mock_state()

    for round_idx in range(3):
      checkpoint.save_checkpoint(
          self._root_dir, mock_state, round_num=round_idx, keep=2)

    surviving = tf.io.gfile.listdir(self._root_dir)
    self.assertCountEqual(
        surviving, ['checkpoint_00000001', 'checkpoint_00000002'])
예제 #6
0
  def test_load_state(self):
    """A state saved with save_state loads back with the same tree and leaves."""
    temp_dir = self.get_temp_dir()
    path = os.path.join(temp_dir, 'state')
    init_state = test_util.create_mock_state()
    serialization.save_state(init_state, path)

    state = serialization.load_state(path)

    # `jax.tree_flatten` is a deprecated alias (removed in newer JAX
    # releases); `jax.tree_util.tree_flatten` is the stable spelling.
    expected_flat, expected_tree_def = jax.tree_util.tree_flatten(init_state)
    actual_flat, actual_tree_def = jax.tree_util.tree_flatten(state)
    # Check structure first. Equal treedefs imply equal leaf counts, so the
    # zip below cannot silently skip unmatched leaves, and a structural
    # mismatch is reported directly rather than as a confusing leaf diff.
    self.assertEqual(expected_tree_def, actual_tree_def)
    for expected_array, actual_array in zip(expected_flat, actual_flat):
      self.assertAllEqual(expected_array, actual_array)
예제 #7
0
 def test_l2_regularizer_weight(self):
     """`weight=0.2` scales the unweighted L2 penalty by 0.2."""
     params = test_util.create_mock_state(seed=0).params
     original_output = regularizers.L2Regularizer()(params)
     output = regularizers.L2Regularizer(weight=0.2)(params)
     # Outputs are presumably float32 (JAX default). The default 7-place
     # tolerance is below float32 resolution here and would require a
     # bit-exact match, so compare with an explicit delta.
     self.assertAlmostEqual(output, original_output * 0.2, delta=1e-5)
예제 #8
0
 def test_l2_regularizer(self):
     """The default L2 penalty of the seed-0 mock params is ~37.64189."""
     params = test_util.create_mock_state(seed=0).params
     output = regularizers.L2Regularizer()(params)
     # The reference value carries only 5 decimals, and neighboring float32
     # values near 37.6 are ~3.8e-6 apart. The default 7-place tolerance
     # therefore demands a bit-exact match, so use an explicit delta.
     self.assertAlmostEqual(output, 37.64189, delta=1e-5)
예제 #9
0
  def test_save_checkpoint(self):
    """Saving without an explicit round_num yields one entry, checkpoint_00000000."""
    mock_state = test_util.create_mock_state()
    checkpoint.save_checkpoint(self._root_dir, mock_state)

    dir_entries = tf.io.gfile.listdir(self._root_dir)
    self.assertEqual(dir_entries, ['checkpoint_00000000'])
예제 #10
0
 def test_create_mock_state(self):
   """Seed-0 mock params have a single 'linear' entry and a tree_size of 124."""
   mock_params = test_util.create_mock_state(seed=0).params
   top_level_names = list(mock_params.keys())
   self.assertEqual(top_level_names, ['linear'])
   self.assertEqual(hk.data_structures.tree_size(mock_params), 124)