def test_numpy_serialization(self):
  """Round-trips numpy scalars and arrays of many dtypes through msgpack.

  For every dtype name in the list, a random scalar and random arrays of
  several shapes are serialized and restored; both the dtype and the values
  must survive unchanged.
  """
  dtype_names = [
      'byte', 'b', 'ubyte', 'short', 'h', 'ushort', 'i', 'uint', 'intp',
      'p', 'uintp', 'long', 'l', 'longlong', 'q', 'ulonglong', 'half', 'e',
      'f', 'double', 'd', 'longdouble', 'g', 'cfloat', 'cdouble',
      'clongdouble', 'm', 'bool8', 'b1', 'int64', 'i8', 'uint64', 'u8',
      'float16', 'f2', 'float32', 'f4', 'float64', 'f8', 'float128', 'f16',
      'complex64', 'c8', 'complex128', 'c16', 'complex256', 'c32', 'm8',
      'int32', 'i4', 'uint32', 'u4', 'int16', 'i2', 'uint16', 'u2', 'int8',
      'i1', 'uint8', 'u1', 'complex_', 'int0', 'uint0', 'single', 'csingle',
      'singlecomplex', 'float_', 'intc', 'uintc', 'int_', 'longfloat',
      'clongfloat', 'longcomplex', 'bool_', 'int', 'float', 'complex', 'bool'
  ]
  array_shapes = [(), (5,), (10, 10), (1, 20, 30, 1)]

  def check_roundtrip(original):
    # Serialize, restore, then compare dtype first and contents second.
    packed = serialization.msgpack_serialize(original)
    restored = serialization.msgpack_restore(packed)
    self.assertEqual(restored.dtype, original.dtype)
    onp.testing.assert_array_equal(restored, original)

  # Fixed seed so the sampled values are reproducible across runs.
  onp.random.seed(0)
  for name in dtype_names:
    # Scalar case: index the 0-d array with () to obtain a numpy scalar.
    scalar = onp.random.uniform(-100, 100, size=()).astype(name)[()]
    check_roundtrip(scalar)
    # Array case, including the 0-d array itself.
    for shape in array_shapes:
      check_roundtrip(onp.random.uniform(-100, 100, size=shape).astype(name))
def _restore_from_checkpoint(model, checkpoint_file):
  """Restores `model` from the 'target' entry of a msgpack checkpoint file.

  Args:
    model: object handed to `serialization.from_state_dict` as the target to
      restore into.
    checkpoint_file: path of the msgpack-serialized checkpoint, opened through
      `tf.io.gfile`.

  Returns:
    The result of `serialization.from_state_dict(model, ...)` applied to the
    checkpoint's 'target' subtree, with every leaf converted via `jnp.array`.

  Raises:
    ValueError: if the checkpoint has no top-level 'target' key.
  """
  # Read everything up front so the file handle is released before decoding.
  with tf.io.gfile.GFile(checkpoint_file, 'rb') as fp:
    raw_bytes = fp.read()
  checkpoint = serialization.msgpack_restore(raw_bytes)

  if 'target' not in checkpoint:
    raise ValueError('Invalid checkpoint %s: no top-level "target".' %
                     checkpoint_file)

  # `jax.tree_map` is a deprecated alias that newer JAX releases removed;
  # `jax.tree_util.tree_map` is the stable spelling available in all versions.
  checkpoint_model = jax.tree_util.tree_map(jnp.array, checkpoint['target'])
  return serialization.from_state_dict(model, checkpoint_model)
def test_jax_numpy_serialization(self):
  """Round-trips jax.numpy scalars and arrays of several dtypes via msgpack.

  Each dtype is checked with a scalar and with arrays of a few shapes; the
  restored value must keep both its dtype and its contents.
  """
  dtypes_under_test = (
      jnp.bool_, jnp.uint8, jnp.uint16, jnp.uint32, jnp.int8, jnp.int16,
      jnp.int32, jnp.bfloat16, jnp.float16, jnp.float32, jnp.complex64,
  )
  array_shapes = ((), (5,), (10, 10), (1, 20, 30, 1))

  def check_roundtrip(original):
    # Serialize, restore, then compare dtype first and contents second.
    packed = serialization.msgpack_serialize(original)
    restored = serialization.msgpack_restore(packed)
    self.assertEqual(restored.dtype, original.dtype)
    onp.testing.assert_array_equal(restored, original)

  for jax_dtype in dtypes_under_test:
    # Scalar case: index the 0-d array with () after the dtype cast.
    sample = onp.random.uniform(-100, 100, size=())
    check_roundtrip(jnp.array(sample).astype(jax_dtype)[()])
    # Array case, including the 0-d array itself.
    for shape in array_shapes:
      sample = onp.random.uniform(-100, 100, size=shape)
      check_roundtrip(jnp.array(sample).astype(jax_dtype))
def test_serialization_chunking3(self):
  """Serializes a dict holding a 10x10 array with a reduced MAX_CHUNK_SIZE.

  The module-level chunk size is temporarily lowered to 91 * 8 and is always
  put back, even when serialization or restoration fails.
  """
  tmp = {'a': np.ones((10, 10))}
  saved_chunk_size = serialization.MAX_CHUNK_SIZE
  serialization.MAX_CHUNK_SIZE = 91 * 8
  try:
    newtmp = serialization.msgpack_restore(
        serialization.msgpack_serialize(tmp))
  finally:
    # Undo the override so other tests see the original limit.
    serialization.MAX_CHUNK_SIZE = saved_chunk_size
  # Leaf-by-leaf comparison of the original and restored trees.
  jax.tree_multimap(np.testing.assert_array_equal, tmp, newtmp)
def test_restore_chunked(self):
  """Checks that msgpack_restore reads back data written with a small chunk size.

  A random (21, 37) array is serialized with `to_bytes` while MAX_CHUNK_SIZE
  is lowered to 91 * 8, then restored and compared to the original.
  """
  tmp = np.random.uniform(-100, 100, size=(21, 37))
  saved_chunk_size = serialization.MAX_CHUNK_SIZE
  serialization.MAX_CHUNK_SIZE = 91 * 8
  try:
    payload = serialization.to_bytes(tmp)
    restored = serialization.msgpack_restore(payload)
  finally:
    # Undo the override so other tests see the original limit.
    serialization.MAX_CHUNK_SIZE = saved_chunk_size
  np.testing.assert_array_equal(restored, tmp)
def test_restore_unchunked(self):
  """Checks that msgpack_restore still works for unchunked inputs."""

  def msgpack_serialize_legacy(pytree):
    """Packs `pytree` directly with msgpack.packb, without any chunking."""
    packed = msgpack.packb(
        pytree,
        default=serialization._msgpack_ext_pack,
        strict_types=True,
    )
    return packed

  tmp = np.random.uniform(-100, 100, size=(21, 37))
  legacy_bytes = msgpack_serialize_legacy(tmp)

  # Lower the chunk size only while restoring, and always put it back.
  saved_chunk_size = serialization.MAX_CHUNK_SIZE
  serialization.MAX_CHUNK_SIZE = 91 * 8
  try:
    restored = serialization.msgpack_restore(legacy_bytes)
  finally:
    serialization.MAX_CHUNK_SIZE = saved_chunk_size
  np.testing.assert_array_equal(restored, tmp)
def per_host_sum_fs(in_tree, step):
  """Execute sum on in_tree's leaves across each host.

  Data is shared via the filesystem: every host writes its own values to a
  file under `FLAGS.model_dir` and then polls for the files of the others.

  Args:
    in_tree: pytree w. array leaves.
    step: int: step number for marking temporary files.

  Returns:
    out_tree w. same shape as in_tree, result of sum across in_trees from
    each host.
  """

  def fname(step, host_id):
    return os.path.join(FLAGS.model_dir, f'partial_bleu_{step}_{host_id}')

  this_host = jax.host_id()
  own_path = fname(step, this_host)

  # Write this host's data to filesystem.
  logging.info('saving partial bleu stats: %s', own_path)
  with tf.io.gfile.GFile(own_path, 'wb') as fp:
    fp.write(serialization.msgpack_serialize(list(in_tree)))

  # Load other hosts' data by polling filesystem for known files.
  results = {k: None for k in jax.host_ids()}
  results[this_host] = tuple(in_tree)

  def pending_hosts():
    # A host is pending only while its slot is still None. Testing truthiness
    # instead would treat an empty tuple as "missing" and poll forever.
    return [k for k, v in results.items() if v is None]

  unfinished = pending_hosts()
  while unfinished:
    for host_id in unfinished:
      path = fname(step, host_id)
      # If file exists, read contents.
      if not tf.io.gfile.exists(path):
        continue
      with tf.io.gfile.GFile(path, 'rb') as fp:
        data = fp.read()
      try:
        results[host_id] = tuple(serialization.msgpack_restore(data))
      # Catch incomplete written file edgecase and continue looping.
      except msgpack.exceptions.UnpackValueError:
        pass
    unfinished = pending_hosts()
    # Only back off when at least one host has not reported yet.
    if unfinished:
      time.sleep(1)

  # Return sum-aggregated partial bleu statistics.
  return functools.reduce(lambda x, y: jax.tree_multimap(np.add, x, y),
                          results.values())
def restore_checkpoint(ckpt_dir, target, step=None, prefix='checkpoint_'):
  """Restore last/best checkpoint from checkpoints in path.

  Sorts the checkpoint files naturally, returning the highest-valued
  file, e.g.:
    ckpt_1, ckpt_2, ckpt_3 --> ckpt_3
    ckpt_0.01, ckpt_0.1, ckpt_0.001 --> ckpt_0.1
    ckpt_-1.0, ckpt_1.0, ckpt_1e5 --> ckpt_1e5

  Args:
    ckpt_dir: str: directory of checkpoints to restore from.
    target: matching object to rebuild via deserialized state-dict. If None,
      the deserialized state-dict is returned as-is.
    step: int: step number to load or None to load latest. A step of 0 is a
      valid explicit step and selects the checkpoint for step 0.
    prefix: str: name prefix of checkpoint files.

  Returns:
    Restored `target` updated from checkpoint file, or if no step specified and
    no checkpoint files present, returns the passed-in `target` unchanged.

  Raises:
    ValueError: if `step` is given and no matching checkpoint file exists.
  """
  # Compare against None explicitly: a plain truthiness test would treat
  # step=0 as "no step given" and silently load the latest checkpoint.
  if step is not None:
    ckpt_path = _checkpoint_path(ckpt_dir, step, prefix)
    if not gfile.exists(ckpt_path):
      raise ValueError(f'Matching checkpoint not found: {ckpt_path}')
  else:
    glob_path = os.path.join(ckpt_dir, f'{prefix}*')
    checkpoint_files = natural_sort(gfile.glob(glob_path))
    # Exclude the '<prefix>tmp' path so it is never picked as the latest.
    ckpt_tmp_path = _checkpoint_path(ckpt_dir, 'tmp', prefix)
    checkpoint_files = [f for f in checkpoint_files if f != ckpt_tmp_path]
    if not checkpoint_files:
      return target
    ckpt_path = checkpoint_files[-1]

  logging.info('Restoring checkpoint from %s', ckpt_path)
  with gfile.GFile(ckpt_path, 'rb') as fp:
    checkpoint_contents = fp.read()

  if target is None:
    return serialization.msgpack_restore(checkpoint_contents)
  return serialization.from_bytes(target, checkpoint_contents)
def test_complex_serialization(self):
  """Python complex numbers must survive a msgpack round trip unchanged."""
  for value in (1j, 1 + 2j):
    packed = serialization.msgpack_serialize(value)
    roundtripped = serialization.msgpack_restore(packed)
    self.assertEqual(value, roundtripped)
def restore_checkpoint(ckpt_dir,
                       target,
                       step=None,
                       prefix='checkpoint_',
                       parallel=True):
  """Restore last/best checkpoint from checkpoints in path.

  Sorts the checkpoint files naturally, returning the highest-valued
  file, e.g.:
    ckpt_1, ckpt_2, ckpt_3 --> ckpt_3
    ckpt_0.01, ckpt_0.1, ckpt_0.001 --> ckpt_0.1
    ckpt_-1.0, ckpt_1.0, ckpt_1e5 --> ckpt_1e5

  Args:
    ckpt_dir: str: checkpoint file or directory of checkpoints to restore from.
    target: matching object to rebuild via deserialized state-dict. If None,
      the deserialized state-dict is returned as-is.
    step: int: step number to load or None to load latest. If specified,
      ckpt_dir must be a directory. A step of 0 is a valid explicit step.
    prefix: str: name prefix of checkpoint files.
    parallel: bool: whether to load seekable checkpoints in parallel, for speed.

  Returns:
    Restored `target` updated from checkpoint file, or if no step specified and
    no checkpoint files present, returns the passed-in `target` unchanged.
    If a file path is specified and is not found, the passed-in `target` will be
    returned. This is to match the behavior of the case where a directory path
    is specified but the directory has not yet been created.

  Raises:
    ValueError: if `step` is given and no matching checkpoint file exists.
  """
  # Compare against None explicitly: a plain truthiness test would treat
  # step=0 as "no step given" and silently load the latest checkpoint.
  if step is not None:
    ckpt_path = _checkpoint_path(ckpt_dir, step, prefix)
    if not gfile.exists(ckpt_path):
      raise ValueError(f'Matching checkpoint not found: {ckpt_path}')
  else:
    if gfile.isdir(ckpt_dir):
      ckpt_path = latest_checkpoint(ckpt_dir, prefix)
      if not ckpt_path:
        logging.info(f'Found no checkpoint files in {ckpt_dir}')
        return target
    else:
      # `ckpt_dir` is treated as the path of a single checkpoint file.
      ckpt_path = ckpt_dir
      if not gfile.exists(ckpt_path):
        logging.info(f'Found no checkpoint file at {ckpt_path}')
        return target

  logging.info('Restoring checkpoint from %s', ckpt_path)
  with gfile.GFile(ckpt_path, 'rb') as fp:
    if parallel and fp.seekable():
      buf_size = 128 << 20  # 128M buffer.
      file_size = fp.size()
      num_bufs = file_size / buf_size
      logging.debug('num_bufs: %d', num_bufs)
      # Pre-sized buffer; each worker fills in its own disjoint slice.
      checkpoint_contents = bytearray(file_size)

      def read_chunk(i):
        # NOTE: We have to re-open the file to read each chunk, otherwise the
        # parallelism has no effect. But we could reuse the file pointers
        # within each thread.
        with gfile.GFile(ckpt_path, 'rb') as f:
          f.seek(i * buf_size)
          buf = f.read(buf_size)
          if buf:
            checkpoint_contents[i * buf_size:i * buf_size + len(buf)] = buf
          # Fraction of a full buffer that was read; only used for logging.
          return len(buf) / buf_size

      pool_size = 32
      # The context manager shuts the pool down even if a chunk read raises,
      # so worker threads are not left behind on failure.
      with thread.ThreadPoolExecutor(pool_size) as pool:
        results = list(pool.map(read_chunk, range(int(num_bufs) + 1)))
      logging.debug('results: %s', results)
    else:
      checkpoint_contents = fp.read()

  if target is None:
    return serialization.msgpack_restore(checkpoint_contents)
  return serialization.from_bytes(target, checkpoint_contents)