Example #1
0
def load_state_dict(base_directory) -> Dict[str, Any]:
    """Deserializes the newest checkpoint under `base_directory` into a dict.

    Example:

      data = checkpoint.load_state_dict(base_directory)
      params = data['optimizer']['target']['params']
      module = mnist_lib.MyArchitecture.partial(num_classes=10)
      model = flax.nn.Model(module, params)

    Args:
      base_directory: Directory holding the checkpoints to restore from; see
        `Checkpoint.__init__()`.

    Returns:
      The Flax state restored by `flax.serialization.msgpack_restore`, as a
      dictionary.

    Raises:
      FileNotFoundError: If `base_directory` contains no checkpoint.
    """
    manager = Checkpoint(base_directory)
    # No checkpoint recorded: fail with an explicit error instead of reading.
    if not manager.latest_checkpoint:
        raise FileNotFoundError(f"No checkpoint found in {base_directory}")
    # Resolving the path, opening the file and decoding it all happen inside
    # the `log_activity` context, so the whole restore is logged as one unit.
    with utils.log_activity("load_state_dict"), tf.io.gfile.GFile(
            manager.latest_checkpoint_flax, "rb") as handle:
        return flax.serialization.msgpack_restore(handle.read())
Example #2
0
 def test_log_activity_fails(self):
     """Checks `log_activity` logging when the wrapped body raises.

     Exactly two records are expected: the "<name> ..." start line at INFO,
     then an ERROR line reporting the elapsed time. The exception raised in
     the body must still propagate out of the context manager.
     """
     with self.assertRaises(TestError):  # pylint: disable=g-error-prone-assert-raises, line-too-long
         with self.assertLogs() as logs:
             with utils.log_activity("test_activity"):
                 raise TestError()
     self.assertLen(logs.output, 2)
     self.assertEqual(logs.output[0], "INFO:absl:test_activity ...")
     # The decimal point is escaped: an unescaped "." matches any character,
     # which would also accept durations such as "1234s".
     self.assertRegex(logs.output[1],
                      r"^ERROR:absl:test_activity FAILED after \d+\.\d\ds")
Example #3
0
 def test_log_activity(self):
     """Checks `log_activity` logging when the wrapped body succeeds.

     Exactly two INFO records are expected: the "<name> ..." start line, then
     a "finished after <secs>s." line with a two-decimal duration.
     """
     with self.assertLogs() as logs:
         with utils.log_activity("test_activity"):
             pass
     self.assertLen(logs.output, 2)
     self.assertEqual(logs.output[0], "INFO:absl:test_activity ...")
     # Both dots are literal: the duration's decimal point and the
     # sentence-ending period. Unescaped, "." would match any character.
     self.assertRegex(
         logs.output[1],
         r"^INFO:absl:test_activity finished after \d+\.\d\ds\.$")