Example #1
    def test_numpy_serialization(self):
        normal_dtypes = [
            'byte', 'b', 'ubyte', 'short', 'h', 'ushort', 'i', 'uint', 'intp',
            'p', 'uintp', 'long', 'l', 'longlong', 'q', 'ulonglong', 'half',
            'e', 'f', 'double', 'd', 'longdouble', 'g', 'cfloat', 'cdouble',
            'clongdouble', 'm', 'bool8', 'b1', 'int64', 'i8', 'uint64', 'u8',
            'float16', 'f2', 'float32', 'f4', 'float64', 'f8', 'float128',
            'f16', 'complex64', 'c8', 'complex128', 'c16', 'complex256', 'c32',
            'm8', 'int32', 'i4', 'uint32', 'u4', 'int16', 'i2', 'uint16', 'u2',
            'int8', 'i1', 'uint8', 'u1', 'complex_', 'int0', 'uint0', 'single',
            'csingle', 'singlecomplex', 'float_', 'intc', 'uintc', 'int_',
            'longfloat', 'clongfloat', 'longcomplex', 'bool_', 'int', 'float',
            'complex', 'bool'
        ]
        onp.random.seed(0)
        for dtype in normal_dtypes:
            v = onp.random.uniform(-100, 100, size=()).astype(dtype)[()]
            restored_v = serialization.msgpack_restore(
                serialization.msgpack_serialize(v))
            self.assertEqual(restored_v.dtype, v.dtype)
            onp.testing.assert_array_equal(restored_v, v)

            for shape in [(), (5, ), (10, 10), (1, 20, 30, 1)]:
                arr = onp.random.uniform(-100, 100, size=shape).astype(dtype)
                restored_arr = serialization.msgpack_restore(
                    serialization.msgpack_serialize(arr))
                self.assertEqual(restored_arr.dtype, arr.dtype)
                onp.testing.assert_array_equal(restored_arr, arr)
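
The tests above boil down to a simple round trip. Here is a minimal standalone sketch (assuming flax and NumPy are installed; the pytree contents below are placeholders): msgpack_serialize turns a pytree of arrays into bytes, and msgpack_restore rebuilds it with dtypes preserved.

import numpy as np
from flax import serialization

# Hypothetical pytree with array and scalar leaves.
pytree = {'w': np.ones((2, 3), dtype=np.float32), 'step': np.int32(7)}

data = serialization.msgpack_serialize(pytree)   # -> bytes
restored = serialization.msgpack_restore(data)   # -> dict of NumPy leaves

np.testing.assert_array_equal(restored['w'], pytree['w'])
assert restored['step'] == pytree['step']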
Example #2
def _restore_from_checkpoint(model, checkpoint_file):
    with tf.io.gfile.GFile(checkpoint_file, 'rb') as fp:
        checkpoint = serialization.msgpack_restore(fp.read())
        if 'target' not in checkpoint:
            raise ValueError('Invalid checkpoint %s: no top-level "target".' %
                             checkpoint_file)
        checkpoint_model = checkpoint['target']
        checkpoint_model = jax.tree_map(jnp.array, checkpoint_model)
        return serialization.from_state_dict(model, checkpoint_model)
Example #3
    def test_jax_numpy_serialization(self):
        jax_dtypes = [
            jnp.bool_, jnp.uint8, jnp.uint16, jnp.uint32, jnp.int8, jnp.int16,
            jnp.int32, jnp.bfloat16, jnp.float16, jnp.float32, jnp.complex64
        ]
        for dtype in jax_dtypes:
            v = jnp.array(onp.random.uniform(-100, 100,
                                             size=())).astype(dtype)[()]
            restored_v = serialization.msgpack_restore(
                serialization.msgpack_serialize(v))
            self.assertEqual(restored_v.dtype, v.dtype)
            onp.testing.assert_array_equal(restored_v, v)

            for shape in [(), (5, ), (10, 10), (1, 20, 30, 1)]:
                arr = jnp.array(onp.random.uniform(-100, 100,
                                                   size=shape)).astype(dtype)
                restored_arr = serialization.msgpack_restore(
                    serialization.msgpack_serialize(arr))
                self.assertEqual(restored_arr.dtype, arr.dtype)
                onp.testing.assert_array_equal(restored_arr, arr)
Example #4
    def test_serialization_chunking3(self):
        old_chunksize = serialization.MAX_CHUNK_SIZE
        serialization.MAX_CHUNK_SIZE = 91 * 8
        try:
            tmp = {'a': np.ones((10, 10))}
            tmpbytes = serialization.msgpack_serialize(tmp)
            newtmp = serialization.msgpack_restore(tmpbytes)
        finally:
            serialization.MAX_CHUNK_SIZE = old_chunksize

        jax.tree_multimap(np.testing.assert_array_equal, tmp, newtmp)
Example #5
    def test_restore_chunked(self):
        old_chunksize = serialization.MAX_CHUNK_SIZE
        serialization.MAX_CHUNK_SIZE = 91 * 8
        try:
            tmp = np.random.uniform(-100, 100, size=(21, 37))
            serialized = serialization.to_bytes(tmp)
            restored = serialization.msgpack_restore(serialized)
        finally:
            serialization.MAX_CHUNK_SIZE = old_chunksize

        np.testing.assert_array_equal(restored, tmp)
Example #6
    def test_restore_unchunked(self):
        """Check if mgspack_restore works for unchunked inputs."""
        def msgpack_serialize_legacy(pytree):
            """Old implementation that was not chunking."""
            return msgpack.packb(pytree,
                                 default=serialization._msgpack_ext_pack,
                                 strict_types=True)

        tmp = np.random.uniform(-100, 100, size=(21, 37))
        serialized = msgpack_serialize_legacy(tmp)
        old_chunksize = serialization.MAX_CHUNK_SIZE
        serialization.MAX_CHUNK_SIZE = 91 * 8
        try:
            restored = serialization.msgpack_restore(serialized)
        finally:
            serialization.MAX_CHUNK_SIZE = old_chunksize

        np.testing.assert_array_equal(restored, tmp)
Example #7
def per_host_sum_fs(in_tree, step):
    """Execute sum on in_tree's leaves across each host.

  Data is shared via the filesystem.

  Args:
    in_tree: pytree w. array leaves.
    step: int: step number for marking temporary files.

  Returns:
    out_tree w. same shape as in_tree, result of sum across in_trees
    from each host.
  """
    def fname(step, host_id):
        return os.path.join(FLAGS.model_dir, f'partial_bleu_{step}_{host_id}')

    # Write this host's data to filesystem.
    logging.info('saving partial bleu stats: %s', fname(step, jax.host_id()))
    with tf.io.gfile.GFile(fname(step, jax.host_id()), 'wb') as fp:
        fp.write(serialization.msgpack_serialize(list(in_tree)))
    # Load other hosts' data by polling filesystem for known files.
    results = {k: None for k in jax.host_ids()}
    results[jax.host_id()] = tuple(in_tree)
    while not all(results.values()):
        unfinished = [k for k in results if results[k] is None]
        for host_id in unfinished:
            # If file exists, read contents.
            if tf.io.gfile.exists(fname(step, host_id)):
                with tf.io.gfile.GFile(fname(step, host_id), 'rb') as fp:
                    data = fp.read()
                try:
                    res = serialization.msgpack_restore(data)
                    results[host_id] = tuple(res)
                # Catch the incompletely-written-file edge case and keep polling.
                except msgpack.exceptions.UnpackValueError:
                    pass
        time.sleep(1)
    # Return sum-aggregated partial bleu statistics.
    return functools.reduce(lambda x, y: jax.tree_multimap(np.add, x, y),
                            results.values())
Example #8
def restore_checkpoint(ckpt_dir, target, step=None, prefix='checkpoint_'):
    """Restore last/best checkpoint from checkpoints in path.

  Sorts the checkpoint files naturally, returning the highest-valued
  file, e.g.:
    ckpt_1, ckpt_2, ckpt_3 --> ckpt_3
    ckpt_0.01, ckpt_0.1, ckpt_0.001 --> ckpt_0.1
    ckpt_-1.0, ckpt_1.0, ckpt_1e5 --> ckpt_1e5

  Args:
    ckpt_dir: str: directory of checkpoints to restore from.
    target: matching object to rebuild via deserialized state-dict. If None,
      the deserialized state-dict is returned as-is.
    step: int: step number to load or None to load latest.
    prefix: str: name prefix of checkpoint files.

  Returns:
    Restored `target` updated from checkpoint file, or if no step specified and
    no checkpoint files present, returns the passed-in `target` unchanged.
  """
    if step:
        ckpt_path = _checkpoint_path(ckpt_dir, step, prefix)
        if not gfile.exists(ckpt_path):
            raise ValueError(f'Matching checkpoint not found: {ckpt_path}')
    else:
        glob_path = os.path.join(ckpt_dir, f'{prefix}*')
        checkpoint_files = natural_sort(gfile.glob(glob_path))
        ckpt_tmp_path = _checkpoint_path(ckpt_dir, 'tmp', prefix)
        checkpoint_files = [f for f in checkpoint_files if f != ckpt_tmp_path]
        if not checkpoint_files:
            return target
        ckpt_path = checkpoint_files[-1]

    logging.info('Restoring checkpoint from %s', ckpt_path)
    with gfile.GFile(ckpt_path, 'rb') as fp:
        if target is None:
            return serialization.msgpack_restore(fp.read())
        else:
            return serialization.from_bytes(target, fp.read())
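
A hedged usage sketch for the helper above (the checkpoint directory and the train_state pytree are placeholder names): with target=None the raw state dict produced by msgpack_restore is returned, while a matching target is rebuilt via from_bytes.

# Inspect the raw nested state dict stored in the latest checkpoint.
raw_state = restore_checkpoint('/tmp/my_model', target=None)

# Rebuild an existing pytree/dataclass from a specific step instead.
train_state = restore_checkpoint('/tmp/my_model', target=train_state, step=1000)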
Example #9
    def test_complex_serialization(self):
        for x in [1j, 1 + 2j]:
            restored_x = serialization.msgpack_restore(
                serialization.msgpack_serialize(x))
            self.assertEqual(x, restored_x)
Example #10
def restore_checkpoint(ckpt_dir,
                       target,
                       step=None,
                       prefix='checkpoint_',
                       parallel=True):
    """Restore last/best checkpoint from checkpoints in path.

  Sorts the checkpoint files naturally, returning the highest-valued
  file, e.g.:
    ckpt_1, ckpt_2, ckpt_3 --> ckpt_3
    ckpt_0.01, ckpt_0.1, ckpt_0.001 --> ckpt_0.1
    ckpt_-1.0, ckpt_1.0, ckpt_1e5 --> ckpt_1e5

  Args:
    ckpt_dir: str: checkpoint file or directory of checkpoints to restore from.
    target: matching object to rebuild via deserialized state-dict. If None,
      the deserialized state-dict is returned as-is.
    step: int: step number to load or None to load latest. If specified,
      ckpt_dir must be a directory.
    prefix: str: name prefix of checkpoint files.
    parallel: bool: whether to load seekable checkpoints in parallel, for speed.

  Returns:
    Restored `target` updated from checkpoint file, or if no step specified and
    no checkpoint files present, returns the passed-in `target` unchanged.
    If a file path is specified and is not found, the passed-in `target` will be
    returned. This is to match the behavior of the case where a directory path
    is specified but the directory has not yet been created.
  """
    if step:
        ckpt_path = _checkpoint_path(ckpt_dir, step, prefix)
        if not gfile.exists(ckpt_path):
            raise ValueError(f'Matching checkpoint not found: {ckpt_path}')
    else:
        if gfile.isdir(ckpt_dir):
            ckpt_path = latest_checkpoint(ckpt_dir, prefix)
            if not ckpt_path:
                logging.info(f'Found no checkpoint files in {ckpt_dir}')
                return target
        else:
            ckpt_path = ckpt_dir
            if not gfile.exists(ckpt_path):
                logging.info(f'Found no checkpoint file at {ckpt_path}')
                return target

    logging.info('Restoring checkpoint from %s', ckpt_path)
    with gfile.GFile(ckpt_path, 'rb') as fp:
        if parallel and fp.seekable():
            buf_size = 128 << 20  # 128M buffer.
            num_bufs = fp.size() / buf_size
            logging.debug('num_bufs: %d', num_bufs)
            checkpoint_contents = bytearray(fp.size())

            def read_chunk(i):
                # NOTE: We have to re-open the file to read each chunk, otherwise the
                # parallelism has no effect. But we could reuse the file pointers
                # within each thread.
                with gfile.GFile(ckpt_path, 'rb') as f:
                    f.seek(i * buf_size)
                    buf = f.read(buf_size)
                    if buf:
                        checkpoint_contents[i * buf_size:i * buf_size +
                                            len(buf)] = buf
                    return len(buf) / buf_size

            pool_size = 32
            pool = thread.ThreadPoolExecutor(pool_size)
            results = pool.map(read_chunk, range(int(num_bufs) + 1))
            results = list(results)
            pool.shutdown(wait=False)
            logging.debug('results: %s', results)
        else:
            checkpoint_contents = fp.read()

        if target is None:
            return serialization.msgpack_restore(checkpoint_contents)
        else:
            return serialization.from_bytes(target, checkpoint_contents)