def test_numpy_serialization(self):
    """Every listed numpy dtype round-trips through msgpack unchanged."""
    normal_dtypes = [
        'byte', 'b', 'ubyte', 'short', 'h', 'ushort', 'i', 'uint', 'intp',
        'p', 'uintp', 'long', 'l', 'longlong', 'q', 'ulonglong', 'half',
        'e', 'f', 'double', 'd', 'longdouble', 'g', 'cfloat', 'cdouble',
        'clongdouble', 'm', 'bool8', 'b1', 'int64', 'i8', 'uint64', 'u8',
        'float16', 'f2', 'float32', 'f4', 'float64', 'f8', 'float128',
        'f16', 'complex64', 'c8', 'complex128', 'c16', 'complex256', 'c32',
        'm8', 'int32', 'i4', 'uint32', 'u4', 'int16', 'i2', 'uint16', 'u2',
        'int8', 'i1', 'uint8', 'u1', 'complex_', 'int0', 'uint0', 'single',
        'csingle', 'singlecomplex', 'float_', 'intc', 'uintc', 'int_',
        'longfloat', 'clongfloat', 'longcomplex', 'bool_', 'int', 'float',
        'complex', 'bool'
    ]
    onp.random.seed(0)  # deterministic sample values across runs
    shapes = ((), (5,), (10, 10), (1, 20, 30, 1))
    for name in normal_dtypes:
        # Scalar case: index with () to unwrap the 0-d array to a numpy scalar.
        scalar = onp.random.uniform(-100, 100, size=()).astype(name)[()]
        recovered = serialization.msgpack_restore(
            serialization.msgpack_serialize(scalar))
        self.assertEqual(recovered.dtype, scalar.dtype)
        onp.testing.assert_array_equal(recovered, scalar)

        # Array cases: 0-d through 4-d shapes.
        for shape in shapes:
            original = onp.random.uniform(-100, 100, size=shape).astype(name)
            restored = serialization.msgpack_restore(
                serialization.msgpack_serialize(original))
            self.assertEqual(restored.dtype, original.dtype)
            onp.testing.assert_array_equal(restored, original)
    def test_jax_numpy_serialization(self):
        """JAX arrays of each supported dtype round-trip through msgpack."""
        jax_dtypes = [
            jnp.bool_, jnp.uint8, jnp.uint16, jnp.uint32, jnp.int8, jnp.int16,
            jnp.int32, jnp.bfloat16, jnp.float16, jnp.float32, jnp.complex64
        ]
        for dtype in jax_dtypes:
            # Scalar case: () index unwraps the 0-d array.
            scalar = jnp.array(onp.random.uniform(-100, 100,
                                                  size=())).astype(dtype)[()]
            recovered = serialization.msgpack_restore(
                serialization.msgpack_serialize(scalar))
            self.assertEqual(recovered.dtype, scalar.dtype)
            onp.testing.assert_array_equal(recovered, scalar)

            # Array cases: 0-d through 4-d shapes.
            for shape in ((), (5,), (10, 10), (1, 20, 30, 1)):
                original = jnp.array(onp.random.uniform(-100, 100,
                                                        size=shape)).astype(dtype)
                restored = serialization.msgpack_restore(
                    serialization.msgpack_serialize(original))
                self.assertEqual(restored.dtype, original.dtype)
                onp.testing.assert_array_equal(restored, original)
Example #3
0
    def test_serialization_chunking3(self):
        """A pytree round-trips intact when forced through small chunks."""
        saved_limit = serialization.MAX_CHUNK_SIZE
        # Shrink the chunk limit so the 10x10 array must be split.
        serialization.MAX_CHUNK_SIZE = 91 * 8
        try:
            original = {'a': np.ones((10, 10))}
            encoded = serialization.msgpack_serialize(original)
            decoded = serialization.msgpack_restore(encoded)
        finally:
            # Always restore the module-level limit for other tests.
            serialization.MAX_CHUNK_SIZE = saved_limit

        jax.tree_multimap(np.testing.assert_array_equal, original, decoded)
Example #4
0
def per_host_sum_fs(in_tree, step):
    """Execute sum on in_tree's leaves across each host.

    Data is shared via the filesystem: this host writes its own values to a
    well-known file, then polls until every other host's file for this step
    appears and parses cleanly.

    Args:
      in_tree: pytree w. array leaves.
      step: int: step number for marking temporary files.

    Returns:
      out_tree w. same shape as in_tree, result of sum across in_trees
      from each host.
    """
    def fname(step, host_id):
        # One file per (step, host) pair; FLAGS.model_dir is presumably a
        # path visible to all hosts (e.g. shared storage) — TODO confirm.
        return os.path.join(FLAGS.model_dir, f'partial_bleu_{step}_{host_id}')

    # Write this host's data to filesystem.
    logging.info('saving partial bleu stats: %s', fname(step, jax.host_id()))
    with tf.io.gfile.GFile(fname(step, jax.host_id()), 'wb') as fp:
        fp.write(serialization.msgpack_serialize(list(in_tree)))
    # Load other hosts' data by polling filesystem for known files.
    # NOTE(review): the loop exits when all(results.values()) — this assumes
    # each host's tuple is truthy; an empty in_tree would poll forever.
    results = {k: None for k in jax.host_ids()}
    # Our own contribution is already in memory; no need to re-read our file.
    results[jax.host_id()] = tuple(in_tree)
    while not all(results.values()):
        unfinished = [k for k in results if results[k] is None]
        for host_id in unfinished:
            # If file exists, read contents.
            if tf.io.gfile.exists(fname(step, host_id)):
                with tf.io.gfile.GFile(fname(step, host_id), 'rb') as fp:
                    data = fp.read()
                try:
                    res = serialization.msgpack_restore(data)
                    results[host_id] = tuple(res)
                # Catch incomplete written file edgecase and continue looping.
                except msgpack.exceptions.UnpackValueError:
                    pass
        # Throttle the polling loop.
        time.sleep(1)
    # Return sum-aggregated partial bleu statistics.
    return functools.reduce(lambda x, y: jax.tree_multimap(np.add, x, y),
                            results.values())
 def test_complex_serialization(self):
     """Python complex scalars survive a msgpack round trip."""
     for value in (1j, 1 + 2j):
         serialized = serialization.msgpack_serialize(value)
         self.assertEqual(value, serialization.msgpack_restore(serialized))