def test_numpy_serialization(self):
  normal_dtypes = [
      'byte', 'b', 'ubyte', 'short', 'h', 'ushort', 'i', 'uint', 'intp',
      'p', 'uintp', 'long', 'l', 'longlong', 'q', 'ulonglong', 'half',
      'e', 'f', 'double', 'd', 'longdouble', 'g', 'cfloat', 'cdouble',
      'clongdouble', 'm', 'bool8', 'b1', 'int64', 'i8', 'uint64', 'u8',
      'float16', 'f2', 'float32', 'f4', 'float64', 'f8', 'float128',
      'f16', 'complex64', 'c8', 'complex128', 'c16', 'complex256', 'c32',
      'm8', 'int32', 'i4', 'uint32', 'u4', 'int16', 'i2', 'uint16', 'u2',
      'int8', 'i1', 'uint8', 'u1', 'complex_', 'int0', 'uint0', 'single',
      'csingle', 'singlecomplex', 'float_', 'intc', 'uintc', 'int_',
      'longfloat', 'clongfloat', 'longcomplex', 'bool_', 'int', 'float',
      'complex', 'bool'
  ]
  onp.random.seed(0)
  for dtype in normal_dtypes:
    v = onp.random.uniform(-100, 100, size=()).astype(dtype)[()]
    restored_v = serialization.msgpack_restore(
        serialization.msgpack_serialize(v))
    self.assertEqual(restored_v.dtype, v.dtype)
    onp.testing.assert_array_equal(restored_v, v)

    for shape in [(), (5,), (10, 10), (1, 20, 30, 1)]:
      arr = onp.random.uniform(-100, 100, size=shape).astype(dtype)
      restored_arr = serialization.msgpack_restore(
          serialization.msgpack_serialize(arr))
      self.assertEqual(restored_arr.dtype, arr.dtype)
      onp.testing.assert_array_equal(restored_arr, arr)
def test_jax_numpy_serialization(self):
  jax_dtypes = [
      jnp.bool_, jnp.uint8, jnp.uint16, jnp.uint32, jnp.int8, jnp.int16,
      jnp.int32, jnp.bfloat16, jnp.float16, jnp.float32, jnp.complex64
  ]
  for dtype in jax_dtypes:
    v = jnp.array(onp.random.uniform(-100, 100, size=())).astype(dtype)[()]
    restored_v = serialization.msgpack_restore(
        serialization.msgpack_serialize(v))
    self.assertEqual(restored_v.dtype, v.dtype)
    onp.testing.assert_array_equal(restored_v, v)

    for shape in [(), (5,), (10, 10), (1, 20, 30, 1)]:
      arr = jnp.array(onp.random.uniform(-100, 100, size=shape)).astype(dtype)
      restored_arr = serialization.msgpack_restore(
          serialization.msgpack_serialize(arr))
      self.assertEqual(restored_arr.dtype, arr.dtype)
      onp.testing.assert_array_equal(restored_arr, arr)
def test_serialization_chunking3(self):
  old_chunksize = serialization.MAX_CHUNK_SIZE
  serialization.MAX_CHUNK_SIZE = 91 * 8
  try:
    tmp = {'a': np.ones((10, 10))}
    tmpbytes = serialization.msgpack_serialize(tmp)
    newtmp = serialization.msgpack_restore(tmpbytes)
  finally:
    serialization.MAX_CHUNK_SIZE = old_chunksize
  jax.tree_multimap(np.testing.assert_array_equal, tmp, newtmp)
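# Illustrative sketch (not part of the original tests): chunking controlled by
# serialization.MAX_CHUNK_SIZE only changes the on-wire layout, so a tree
# serialized under a small chunk size should restore to the same values as one
# serialized under the default. Assumes `numpy as np` and `flax.serialization`
# are imported as in the tests above.
def _chunking_is_transparent_example():
  tree = {'a': np.arange(1000, dtype=np.float64)}
  default_bytes = serialization.msgpack_serialize(tree)
  old_chunksize = serialization.MAX_CHUNK_SIZE
  serialization.MAX_CHUNK_SIZE = 91 * 8  # same small chunk size as the test
  try:
    chunked_bytes = serialization.msgpack_serialize(tree)
  finally:
    serialization.MAX_CHUNK_SIZE = old_chunksize
  np.testing.assert_array_equal(
      serialization.msgpack_restore(default_bytes)['a'],
      serialization.msgpack_restore(chunked_bytes)['a'])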
def per_host_sum_fs(in_tree, step):
  """Execute sum on in_tree's leaves across each host.

  Data is shared via the filesystem.

  Args:
    in_tree: pytree with array leaves.
    step: int, step number used for marking temporary files.

  Returns:
    out_tree with the same shape as in_tree, the result of summing the
    in_trees from every host.
  """

  def fname(step, host_id):
    return os.path.join(FLAGS.model_dir, f'partial_bleu_{step}_{host_id}')

  # Write this host's data to the filesystem.
  logging.info('saving partial bleu stats: %s', fname(step, jax.host_id()))
  with tf.io.gfile.GFile(fname(step, jax.host_id()), 'wb') as fp:
    fp.write(serialization.msgpack_serialize(list(in_tree)))
  # Load the other hosts' data by polling the filesystem for known files.
  results = {k: None for k in jax.host_ids()}
  results[jax.host_id()] = tuple(in_tree)
  while not all(results.values()):
    unfinished = [k for k in results if results[k] is None]
    for host_id in unfinished:
      # If the file exists, read its contents.
      if tf.io.gfile.exists(fname(step, host_id)):
        with tf.io.gfile.GFile(fname(step, host_id), 'rb') as fp:
          data = fp.read()
        try:
          res = serialization.msgpack_restore(data)
          results[host_id] = tuple(res)
        # Catch the incompletely-written-file edge case and keep polling.
        except msgpack.exceptions.UnpackValueError:
          pass
    time.sleep(1)
  # Return the sum-aggregated partial BLEU statistics.
  return functools.reduce(lambda x, y: jax.tree_multimap(np.add, x, y),
                          results.values())
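# Minimal sketch (added for illustration; not in the original example): the
# final line of per_host_sum_fs reduces the per-host pytrees with an
# element-wise sum. The `results` dict below is a stand-in for the data each
# host would have written; `functools`, `jax`, and `numpy as np` are assumed
# to be imported as in the function above.
def _sum_across_hosts_example():
  results = {
      0: (np.array([1, 2, 3]), np.array(10)),
      1: (np.array([4, 5, 6]), np.array(20)),
  }
  total = functools.reduce(lambda x, y: jax.tree_multimap(np.add, x, y),
                           results.values())
  # total == (array([5, 7, 9]), array(30))
  return total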
def test_complex_serialization(self):
  for x in [1j, 1 + 2j]:
    restored_x = serialization.msgpack_restore(
        serialization.msgpack_serialize(x))
    self.assertEqual(x, restored_x)