def load_state_dict(base_directory) -> Dict[str, Any]:
  """Returns the `state` stored in the most recent checkpoint as a dictionary.

  Example:

    data = checkpoint.load_state_dict(base_directory)
    params = data['optimizer']['target']['params']
    module = mnist_lib.MyArchitecture.partial(num_classes=10)
    model = flax.nn.Model(module, params)

  Args:
    base_directory: Directory holding the checkpoints to restore from. It is
      passed unchanged to `Checkpoint(...)`; see `Checkpoint.__init__()`.

  Returns:
    The dictionary produced by `flax.serialization.msgpack_restore` from the
    raw bytes of the latest Flax checkpoint file.

  Raises:
    FileNotFoundError: If `Checkpoint(base_directory)` reports no latest
      checkpoint.
  """
  restorer = Checkpoint(base_directory)
  latest = restorer.latest_checkpoint
  if not latest:
    raise FileNotFoundError(f"No checkpoint found in {base_directory}")
  # A multi-item `with` enters the managers left to right and exits them in
  # reverse, exactly like nested `with` statements. The activity logging
  # therefore wraps both the file read and the deserialization.
  with utils.log_activity("load_state_dict"), tf.io.gfile.GFile(
      restorer.latest_checkpoint_flax, "rb") as fh:
    payload = fh.read()
    return flax.serialization.msgpack_restore(payload)
def test_log_activity_fails(self):
  """Checks the log records `utils.log_activity` emits when its body raises.

  Expects exactly two records: the "INFO ... ..." start line, then an ERROR
  line reporting the elapsed time with two decimals.
  """
  with self.assertRaises(TestError):  # pylint: disable=g-error-prone-assert-raises, line-too-long
    with self.assertLogs() as logs:
      with utils.log_activity("test_activity"):
        raise TestError()
  # The assertions sit outside `assertRaises`: statements placed after the
  # `raise` inside that block would never execute.
  self.assertLen(logs.output, 2)
  self.assertEqual(logs.output[0], "INFO:absl:test_activity ...")
  # The decimal point is escaped so only a duration such as "0.01s" matches.
  # An unescaped "." would accept any character between the digits.
  self.assertRegex(logs.output[1],
                   r"^ERROR:absl:test_activity FAILED after \d+\.\d\ds")
def test_log_activity(self):
  """Checks the log records `utils.log_activity` emits on normal completion.

  Expects exactly two INFO records: the "test_activity ..." start line and a
  "finished after N.NNs." line.
  """
  with self.assertLogs() as logs:
    with utils.log_activity("test_activity"):
      pass
  self.assertLen(logs.output, 2)
  self.assertEqual(logs.output[0], "INFO:absl:test_activity ...")
  # Both the decimal point and the sentence-ending period are escaped. With
  # a bare "." either position would match any character.
  self.assertRegex(
      logs.output[1],
      r"^INFO:absl:test_activity finished after \d+\.\d\ds\.$")