Ejemplo n.º 1
0
 def test_plasma_store_full_raises(self):
     """Creating a view over data too large for the store raises ``PlasmaStoreFull``.

     A dedicated server is started on a fresh temporary path with
     ``nbytes=10000`` and is always killed before the test returns.
     """
     with tempfile.NamedTemporaryFile() as new_path:
         server = PlasmaStore.start(path=new_path.name, nbytes=10000)
         try:
             with self.assertRaises(plasma.PlasmaStoreFull):
                 # 10000 float64 values take 80000 bytes, well above the
                 # nbytes=10000 the server was started with.
                 PlasmaView(
                     np.random.rand(10000, 1),
                     dummy_path,
                     1,
                     plasma_path=new_path.name,
                 )
         finally:
             # Kill the server even when the expectation above fails, so a
             # failing test does not leave the server running.
             server.kill()
Ejemplo n.º 2
0
 def test_two_servers_do_not_share_object_id_space(self):
     """Views created against two different servers keep their own data.

     Both views are built with the same ``dummy_path`` and ``1`` arguments
     (presumably yielding the same object id -- confirm against
     ``PlasmaView``), but against different server paths. After the second
     view is created, the first must still hold its original contents.
     """
     data_server_1 = np.array([0, 1])
     data_server_2 = np.array([2, 3])
     server_2_path = self.path
     with tempfile.NamedTemporaryFile() as server_1_path:
         server = PlasmaStore.start(path=server_1_path.name, nbytes=10000)
         try:
             arr1 = PlasmaView(
                 data_server_1, dummy_path, 1, plasma_path=server_1_path.name
             )
             self.assertEqual(len(arr1.client.list()), 1)
             self.assertTrue((arr1.array == data_server_1).all())
             arr2 = PlasmaView(
                 data_server_2, dummy_path, 1, plasma_path=server_2_path
             )
             self.assertTrue((arr2.array == data_server_2).all())
             # Re-check the first view: writing through the second server
             # must not have replaced its contents.
             self.assertTrue((arr1.array == data_server_1).all())
         finally:
             # Kill the extra server even if an assertion above fails, so a
             # failing test does not leave the server running.
             server.kill()
Ejemplo n.º 3
0
def cli_main(
    modify_parser: Optional[Callable[[argparse.ArgumentParser], None]] = None
) -> None:
    """Parse training arguments and launch ``main`` via ``distributed_utils``.

    Args:
        modify_parser: Optional callback forwarded to
            ``options.parse_args_and_arch`` to adjust the argument parser.

    When ``cfg.common.use_plasma_view`` is set, a ``PlasmaStore`` is created
    at ``cfg.common.plasma_path`` before training starts. When
    ``args.profile`` is set, the run is wrapped in the CUDA profiler and
    NVTX emission context managers.
    """
    args = options.parse_args_and_arch(
        options.get_training_parser(), modify_parser=modify_parser
    )
    cfg = convert_namespace_to_omegaconf(args)

    if cfg.common.use_plasma_view:
        # NOTE(review): `server` is only read for the log line; presumably the
        # binding also keeps the store alive while training runs -- confirm.
        server = PlasmaStore(path=cfg.common.plasma_path)
        logger.info(f"Started plasma server pid {server.server.pid} {cfg.common.plasma_path}")

    if not args.profile:
        distributed_utils.call_main(cfg, main)
        return

    with torch.cuda.profiler.profile(), torch.autograd.profiler.emit_nvtx():
        distributed_utils.call_main(cfg, main)
Ejemplo n.º 4
0
 def setUp(self) -> None:
     """Start a plasma server on a fresh temporary path and connect to it.

     Sets ``self.tmp_file``, ``self.path``, ``self.server`` and
     ``self.client`` for use by the tests.
     """
     # The temp file is kept open on purpose (no ``with``): its name is used
     # as the server path for the whole test.
     tmp = tempfile.NamedTemporaryFile()  # noqa: P201
     store_path = tmp.name
     self.tmp_file = tmp
     self.path = store_path
     self.server = PlasmaStore.start(path=store_path)
     self.client = plasma.connect(store_path, num_retries=10)