def test_multitenancy(self):
    addr = DEFAULT_HOSTNAME
    port = common.find_free_port()

    # Use noqa to silence flake8.
    # Need to store in an unused variable here to ensure the first
    # object is not destroyed before the second object is created.
    store1 = dist.TCPStore(addr, port, 1, True, multi_tenant=True)  # type: ignore[call-arg] # noqa: F841
    store2 = dist.TCPStore(addr, port, 1, True, multi_tenant=True)  # type: ignore[call-arg] # noqa: F841
def test_address_already_in_use(self):
    with self.assertRaisesRegex(RuntimeError, "^Address already in use$"):
        addr = 'localhost'
        port = common.find_free_port()

        # Use noqa to silence flake8.
        # Need to store in an unused variable here to ensure the first
        # object is not destroyed before the second object is created.
        store1 = c10d.TCPStore(addr, port, True)  # noqa: F841
        store2 = c10d.TCPStore(addr, port, True)  # noqa: F841
def test_address_already_in_use(self):
    err_msg_reg = "^The server socket has failed to listen on any local "
    with self.assertRaisesRegex(RuntimeError, err_msg_reg):
        addr = DEFAULT_HOSTNAME
        port = common.find_free_port()

        # Use noqa to silence flake8.
        # Need to store in an unused variable here to ensure the first
        # object is not destroyed before the second object is created.
        store1 = dist.TCPStore(addr, port, 1, True)  # noqa: F841
        store2 = dist.TCPStore(addr, port, 1, True)  # noqa: F841
def test_address_already_in_use(self):
    if sys.platform == "win32":
        err_msg_reg = "Only one usage of each socket address*"
    else:
        err_msg_reg = "^Address already in use$"
    with self.assertRaisesRegex(RuntimeError, err_msg_reg):
        addr = DEFAULT_HOSTNAME
        port = common.find_free_port()

        # Use noqa to silence flake8.
        # Need to store in an unused variable here to ensure the first
        # object is not destroyed before the second object is created.
        store1 = dist.TCPStore(addr, port, 1, True)  # noqa: F841
        store2 = dist.TCPStore(addr, port, 1, True)  # noqa: F841
def init_distrib(world_rank, world_size, backend: str = "nccl", port_offset: int = 0):
    assert torch.distributed.is_available(), "torch.distributed must be available"

    if "GLOO_SOCKET_IFNAME" not in os.environ:
        os.environ["GLOO_SOCKET_IFNAME"] = get_ifname()

    if "NCCL_SOCKET_IFNAME" not in os.environ:
        os.environ["NCCL_SOCKET_IFNAME"] = get_ifname()

    master_port = int(
        os.environ.get(
            "MASTER_PORT",
            DEFAULT_PORT
            + int(SLURM_JOBID if SLURM_JOBID is not None else 0) % 127
            + port_offset,
        )
    )
    master_addr = os.environ.get("MASTER_ADDR", DEFAULT_MASTER_ADDR)

    tcp_store = distrib.TCPStore(master_addr, master_port, world_size, world_rank == 0)
    distrib.init_process_group(
        backend, store=tcp_store, rank=world_rank, world_size=world_size
    )

    return tcp_store
def test_dist_broadcast_coalesced(self):
    # Set up process group.
    store = c10d.TCPStore('localhost', self.port, self.is_master)
    options = c10d.ProcessGroupGloo.Options()
    options.devices = [c10d.ProcessGroupGloo.create_tcp_device(interface="lo")]
    process_group = c10d.ProcessGroupGloo(store, self.rank, self.world_size, options)

    device = torch.device('cuda')

    target = torch.arange(10, dtype=torch.float64, device=device).chunk(5)

    if self.is_master:
        # All processes should have these tensors in the end.
        tensors = target
    else:
        # Non-master processes start with empty tensors and should be
        # filled with the tensors from the master.
        tensors = torch.zeros(10, device=device).chunk(5)

    c10d._dist_broadcast_coalesced(
        tensors,
        buffer_size=10,
        process_group=process_group)

    if not self.is_master:
        self.assertEqual(tensors, target)
def test_sync_params_no_buffers(self):
    # Set up process group.
    store = c10d.TCPStore('localhost', self.port, self.is_master)
    options = c10d.ProcessGroupGloo.Options()
    options.devices = [c10d.ProcessGroupGloo.create_tcp_device(interface="lo")]
    process_group = c10d.ProcessGroupGloo(store, self.rank, self.world_size, options)

    # Use all available devices on every process here (data is small, so should be fine).
    devices = gpus_for_rank(self.world_size)[self.rank]
    target = torch.arange(10, dtype=torch.float64, device='cuda:0').chunk(5)
    parameter_data = [target]
    parameter_data += [
        torch.zeros(10, device=torch.device('cuda', d)).chunk(5)
        for d in devices[1:]
    ]
    buffer_data = [[]] * len(parameter_data)

    c10d._sync_params(
        process_group,
        parameter_data=parameter_data,
        buffer_data=buffer_data,
        devices=devices,
        broadcast_bucket_size=10,
        broadcast_buffers=False)

    for device_data in parameter_data:
        for i, parameter in enumerate(device_data):
            self.assertEqual(parameter, target[i])
def test_fp16(self):
    store = c10d.TCPStore('localhost', self.port, self.rank == 0)
    process_group = c10d.ProcessGroupNCCL(store, self.rank, self.world_size)

    gpus = gpus_for_rank(self.world_size)[self.rank]
    model = nn.Linear(1, 1, bias=False).cuda(gpus[0]).half()
    nn.init.constant_(model.weight, 1)
    ddp_model = DistributedDataParallel(
        model,
        device_ids=[gpus[0]],
        process_group=process_group,
        bucket_cap_mb=1,
    )

    # Input 2**15, so that the gradients will overflow with a
    # world_size of 2, unless we normalize the gradient by the
    # world_size before the reduction.
    input = torch.Tensor([[2**15]]).cuda(gpus[0]).half()

    # Step model
    ddp_model.train()
    output = ddp_model(input)
    loss = output.sum()
    loss.backward()

    self.assertFalse(
        any(torch.isinf(p.grad).any() for p in ddp_model.parameters()))
def test_nccl_backend(self):
    store = c10d.TCPStore('localhost', self.port, self.is_master)
    process_group = c10d.ProcessGroupNCCL(store, self.rank, self.world_size)
    gpus = gpus_for_rank(self.world_size)[self.rank]
    self._test_ddp_with_process_group(process_group, gpus)
    self._test_ddp_with_process_group(
        process_group,
        list(map(lambda i: torch.device('cuda:' + str(i)), gpus)))
def create_tcp_store(addr="localhost", world_size=1, is_master=True, timeout=timedelta(minutes=5), wait_for_workers=True, jit_class=False): """ Creates a TCP store. Retries if the chosen port is already in use. """ port = find_free_port() if jit_class: timeout_millisecond = int(timeout / timedelta(milliseconds=1)) return torch.classes.dist_c10d.TCPStore(addr, port, world_size, is_master, timeout_millisecond) else: return c10d.TCPStore(addr, port, world_size, is_master, wait_for_workers=wait_for_workers)
def _create_client(self, index, addr, port, world_size):
    client_store = dist.TCPStore(addr, port, world_size, timeout=timedelta(seconds=10))
    self.assertEqual("value".encode(), client_store.get("key"))
    client_store.set(f"new_key{index}", f"new_value{index}")
    self.assertEqual(
        f"next_value{index}".encode(),
        client_store.compare_set(f"new_key{index}", f"new_value{index}", f"next_value{index}"))
def _create_client(self, index, addr, port, world_size, messages):
    try:
        client_store = dist.TCPStore(addr, port, world_size, timeout=timedelta(seconds=10))
        self.assertEqual("value".encode(), client_store.get("key"))
        client_store.set(f"new_key{index}", f"new_value{index}")
        self.assertEqual(
            f"next_value{index}".encode(),
            client_store.compare_set(f"new_key{index}", f"new_value{index}", f"next_value{index}"))
    except Exception:
        messages.put('Caught exception: \n{}exiting process with exit code: {}'
                     .format(traceback.format_exc(), MultiProcessTestCase.TEST_ERROR_EXIT_CODE))
        sys.exit(MultiProcessTestCase.TEST_ERROR_EXIT_CODE)
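The two `_create_client` helpers above only show the worker side of the rendezvous; they assume a master store that already holds the key "key" with the value "value". The following is a minimal, hypothetical sketch of the matching server side (the names `_run_clients` and `find_free_port`, and the use of `multiprocessing`, are assumptions, not part of the original tests):

# Hypothetical server-side counterpart for the _create_client helpers above.
import multiprocessing
from datetime import timedelta

import torch.distributed as dist


def _run_clients(create_client_fn, num_clients=2):
    addr = "localhost"
    port = find_free_port()  # assumed helper, as in the snippets above
    world_size = num_clients + 1

    # The master store owns the underlying TCP server and seeds the key that
    # every client immediately reads back. wait_for_workers=False keeps the
    # constructor from blocking before the client processes are started.
    server_store = dist.TCPStore(addr, port, world_size, True,
                                 timeout=timedelta(seconds=30),
                                 wait_for_workers=False)
    server_store.set("key", "value")

    procs = [
        multiprocessing.Process(target=create_client_fn,
                                args=(i, addr, port, world_size))
        for i in range(num_clients)
    ]
    for p in procs:
        p.start()
    for p in procs:
        p.join()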
def _nccl_init(self, nccl_addr, nccl_ip, nccl_port):
    self.nccl_ip, self.nccl_addr, self.nccl_port = nccl_ip, nccl_addr, nccl_port
    print('Rank {} calling init_process_group. Addr: {}'.format(self.rank, nccl_addr))
    # from https://github.com/pytorch/pytorch/blob/master/test/simulate_nccl_errors.py
    store = dist.TCPStore(self.nccl_ip, self.nccl_port, self.nb_learners, self.rank == 0)
    process_group = dist.ProcessGroupNCCL(store, self.rank, self.nb_learners)
    print('Rank {} initialized process group.'.format(self.rank))
    process_group.barrier()
    print('Rank {} process group barrier finished.'.format(self.rank))
    self.process_group = process_group
    # set optimizer process_group
    self.optimizer.set_process_group(self.process_group)
def init_process(self, rank, args, kwargs):
    dist.init_process_group(self.backend, rank=rank, world_size=self.world_size)
    store = dist.TCPStore("127.0.0.1", 1234, 2, rank == 0, timedelta(seconds=30))
    trainer = self.cls(*args, **kwargs)
    store.set("shared", str(json.dumps(RecDict(trainer.state.shared))))
    trainer._store = store
    trainer._dist = dist
    trainer._rank = rank
    trainer.run()
def create_tcp_store(addr):
    """
    Creates a TCP store. Retries if the chosen port is already in use.
    """
    while True:
        try:
            port = common.find_free_port()
            return c10d.TCPStore(addr, port, True)
        except RuntimeError as error:
            if str(error) == "Address already in use":
                continue
            raise
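Once a store exists, it behaves as a plain distributed key-value service. Below is a minimal sketch of the documented `TCPStore` methods (`set`/`get`/`add`/`wait`/`set_timeout`), assuming a store created by the `create_tcp_store` helper above:

# Minimal sketch of basic TCPStore key-value usage; assumes create_tcp_store
# from the helper above (or any equivalent way to obtain a store).
from datetime import timedelta

store = create_tcp_store("localhost")
store.set_timeout(timedelta(seconds=30))

# set()/get() move raw bytes; values come back as `bytes`.
store.set("status", "ready")
assert store.get("status") == b"ready"

# add() keeps an integer counter under a key and returns the new value.
assert store.add("visits", 1) == 1
assert store.add("visits", 2) == 3

# wait() blocks until the listed keys have been set (or the timeout expires).
store.wait(["status"])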
def test_gloo_backend(self):
    store = c10d.TCPStore('localhost', self.port, self.is_master)
    options = c10d.ProcessGroupGloo.Options()
    options.devices = [c10d.ProcessGroupGloo.create_tcp_device(interface="lo")]
    process_group = c10d.ProcessGroupGloo(store, self.rank, self.world_size, options)
    gpus = gpus_for_rank(self.world_size)[self.rank]
    self._test_ddp_with_process_group(process_group, gpus)
    self._test_ddp_with_process_group(
        process_group,
        list(map(lambda i: torch.device('cuda:' + str(i)), gpus)))
def init_distrib_slurm(
    backend: str = "nccl",
) -> Tuple[int, torch.distributed.TCPStore]:
    r"""Initializes torch.distributed by parsing environment variables set
    by SLURM when `srun` is used or by parsing environment variables set
    by torch.distributed.launch

    :param backend: Which torch.distributed backend to use

    :returns: Tuple of the local_rank (aka which GPU to use for this process)
        and the TCPStore used for the rendezvous
    """
    assert (
        torch.distributed.is_available()
    ), "torch.distributed must be available"

    if "GLOO_SOCKET_IFNAME" not in os.environ:
        os.environ["GLOO_SOCKET_IFNAME"] = get_ifname()

    if "NCCL_SOCKET_IFNAME" not in os.environ:
        os.environ["NCCL_SOCKET_IFNAME"] = get_ifname()

    master_port = int(os.environ.get("MASTER_PORT", DEFAULT_PORT))
    master_addr = os.environ.get("MASTER_ADDR", DEFAULT_MASTER_ADDR)

    # Check to see if we should parse from torch.distributed.launch
    if os.environ.get("LOCAL_RANK", None) is not None:
        local_rank = int(os.environ["LOCAL_RANK"])
        world_rank = int(os.environ["RANK"])
        world_size = int(os.environ["WORLD_SIZE"])
    # Else parse from SLURM if using SLURM
    elif os.environ.get("SLURM_JOBID", None) is not None:
        local_rank = int(os.environ["SLURM_LOCALID"])
        world_rank = int(os.environ["SLURM_PROCID"])
        world_size = int(os.environ["SLURM_NTASKS"])
    # Otherwise setup for just 1 process, this is nice for testing
    else:
        local_rank = 0
        world_rank = 0
        world_size = 1

        # Default port to initialize the TCP store on
        master_port += 3
        # Default address of world rank 0
        master_addr = "127.0.0.3"

    tcp_store = distrib.TCPStore(master_addr, master_port, world_size, world_rank == 0)
    distrib.init_process_group(
        backend, store=tcp_store, rank=world_rank, world_size=world_size
    )

    return local_rank, tcp_store
def init_process(self, rank, args, kwargs, mode):
    dist.init_process_group(self.backend, rank=rank, world_size=self.world_size)
    store = dist.TCPStore(os.environ['STORE_ADDR'], int(os.environ['STORE_PORT']),
                          len(self.world), rank == 0, timedelta(seconds=30))
    self.__trainer._dist = dist
    self.__trainer._rank = rank
    self.__trainer._mode = mode
    self.__trainer._world_size = len(self.world)
    store.set("test", str(json.dumps(RecDict(self.__trainer._metrics.test))))
    self.__trainer._store = store
    self.__trainer.run(*args, **kwargs)
def test_sync_params_with_buffers(self):
    # Set up process group.
    store = c10d.TCPStore('localhost', self.port, self.is_master)
    options = c10d.ProcessGroupGloo.Options()
    options.devices = [c10d.ProcessGroupGloo.create_tcp_device(interface="lo")]
    process_group = c10d.ProcessGroupGloo(store, self.rank, self.world_size, options)

    devices = gpus_for_rank(self.world_size)[self.rank]
    target = torch.arange(10, dtype=torch.float64, device='cuda:0').chunk(5)
    parameter_data = [target]
    parameter_data += [
        torch.zeros(10, device=torch.device('cuda', d)).chunk(5)
        for d in devices[1:]
    ]

    # sync_params should do a dist_broadcast for buffers, so we only
    # populate the master buffers and then check that other processes'
    # tensors end up matching.
    if self.is_master:
        buffer_data = [target]
        buffer_data += [
            torch.zeros(10, device=torch.device('cuda', d)).chunk(5)
            for d in devices[1:]
        ]
    else:
        buffer_data = [
            torch.zeros(10, device=torch.device('cuda', d)).chunk(5)
            for d in devices
        ]

    c10d._sync_params(
        process_group,
        parameter_data=parameter_data,
        buffer_data=buffer_data,
        devices=devices,
        broadcast_bucket_size=10,
        broadcast_buffers=True)

    for device_data in parameter_data:
        for i, parameter in enumerate(device_data):
            self.assertEqual(parameter, target[i])

    for device_data in buffer_data:
        for i, buffer in enumerate(device_data):
            self.assertEqual(buffer, target[i])
def init_distrib_slurm(
    backend: str = "nccl",
) -> Tuple[int, torch.distributed.TCPStore]:  # type: ignore
    r"""Initializes torch.distributed by parsing environment variables set
    by SLURM when ``srun`` is used or by parsing environment variables set
    by torch.distributed.launch

    :param backend: Which torch.distributed backend to use

    :returns: Tuple of the local_rank (aka which GPU to use for this process)
        and the TCPStore used for the rendezvous
    """
    assert (
        torch.distributed.is_available()
    ), "torch.distributed must be available"

    if "GLOO_SOCKET_IFNAME" not in os.environ:
        os.environ["GLOO_SOCKET_IFNAME"] = get_ifname()

    if "NCCL_SOCKET_IFNAME" not in os.environ:
        os.environ["NCCL_SOCKET_IFNAME"] = get_ifname()

    local_rank, world_rank, world_size = get_distrib_size()

    master_addr = os.environ.get("MASTER_ADDR", DEFAULT_MASTER_ADDR)
    master_port = int(os.environ.get("MASTER_PORT", DEFAULT_PORT))
    if SLURM_JOBID is not None:
        master_port += int(SLURM_JOBID) % int(
            os.environ.get("MASTER_PORT_RANGE", DEFAULT_PORT_RANGE)
        )
    if MULTI_PROC_OFFSET is not None:
        master_port += int(MULTI_PROC_OFFSET)

    tcp_store = distrib.TCPStore(  # type: ignore
        master_addr, master_port, world_size, world_rank == 0
    )
    distrib.init_process_group(
        backend, store=tcp_store, rank=world_rank, world_size=world_size
    )

    return local_rank, tcp_store
def init_distrib_slurm(backend="nccl"): if "GLOO_SOCKET_IFNAME" not in os.environ: os.environ["GLOO_SOCKET_IFNAME"] = get_ifname() if "NCCL_SOCKET_IFNAME" not in os.environ: os.environ["NCCL_SOCKET_IFNAME"] = get_ifname() master_port = int(os.environ.get("MASTER_PORT", 8738)) master_addr = os.environ.get("MASTER_ADDR", "127.0.0.1") local_rank = int( os.environ.get("LOCAL_RANK", os.environ.get("SLURM_LOCALID", 0))) world_rank = int(os.environ.get("RANK", os.environ.get("SLURM_PROCID", 0))) world_size = int( os.environ.get("WORLD_SIZE", os.environ.get("SLURM_NTASKS", 1))) tcp_store = distrib.TCPStore(master_addr, master_port, world_size, world_rank == 0) distrib.init_process_group(backend, store=tcp_store, rank=world_rank, world_size=world_size) return local_rank, tcp_store
def test_gloo_backend(self):
    store = c10d.TCPStore('localhost', self.port, self.is_master)
    options = c10d.ProcessGroupGloo.Options()
    options.devices = [c10d.ProcessGroupGloo.create_tcp_device(interface="lo")]
    process_group = c10d.ProcessGroupGloo(store, self.rank, self.world_size, options)
    self._test_ddp_with_process_group(process_group)
def test_nccl_backend(self):
    store = c10d.TCPStore('localhost', self.port, self.is_master)
    process_group = c10d.ProcessGroupNCCL(store, self.rank, self.world_size)
    self._test_ddp_with_process_group(process_group)
def setUp(self):
    addr = 'localhost'
    port = common.find_free_port()
    self.tcpstore = c10d.TCPStore(addr, port, True)
    self.prefix = "test_prefix"
    self.tcpstore.set_timeout(timedelta(seconds=300))
def _create_store(self):
    addr = 'localhost'
    port = common.find_free_port()
    store = c10d.TCPStore(addr, port, True)
    store.set_timeout(timedelta(seconds=300))
    return store
def create_c10d_store(
    is_server: bool,
    server_addr: str,
    server_port: int = -1,
    world_size: int = 1,
    timeout: float = (60 * 10),  # 10 min
    wait_for_workers: bool = True,
    retries=3,
):
    if server_port == -1 and world_size > 1:
        raise ValueError(
            f"server_port must be specified when world_size > 1, got server_port={server_port}, world_size={world_size}"
        )

    if server_port != -1:
        log.info(f"server_port: {server_port}, specified, ignoring retries")

    # only retry when server_port is NOT static
    attempt = retries if server_port == -1 else 1
    while True:
        if server_port != -1:
            port = server_port
        else:
            port = get_free_port()

        log.info(
            f"Creating c10d store on {server_addr}:{port}\n"
            f"  world_size  : {world_size}\n"
            f"  is_server   : {is_server}\n"
            f"  timeout(sec): {timeout}\n"
        )

        try:
            store = dist.TCPStore(
                host_name=server_addr,
                port=port,
                world_size=world_size,
                is_master=is_server,
                timeout=datetime.timedelta(seconds=timeout),
                wait_for_workers=wait_for_workers,
            )
            # skips full rank check when we don't have to wait for all workers
            if wait_for_workers:
                _check_full_rank(store, world_size)
            log.info("Successfully created c10d store")
            return store
        except RuntimeError as e:
            # this is brittle, but the underlying exception type is not properly pybinded
            # so we parse the error msg for now, interestingly this is how torch itself
            # detects timeouts and port conflicts in their own unittests
            # see - caffe2/torch/testing/_internal/common_utils.py
            # TODO properly map the exceptions in pybind (c10d/init.cpp)
            if str(e) == _ADDRESS_IN_USE:  # this will only happen on the server
                if attempt < retries:
                    log.warning(
                        f"port: {port} already in use, attempt: [{attempt}/{retries}]"
                    )
                    attempt += 1
                else:
                    raise RuntimeError(
                        f"on {server_addr}, port: {port} already in use"
                    ) from e
            else:
                raise
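A rough usage sketch for `create_c10d_store` as defined above, assuming one server process and one worker process that agree on the address and port out of band; the address 10.0.0.1, the port 29500, and the key name below are illustrative only:

# Hypothetical two-process usage of create_c10d_store.
# On the server/agent process:
server_store = create_c10d_store(
    is_server=True,
    server_addr="10.0.0.1",
    server_port=29500,
    world_size=2,
)

# On the worker process (same addr/port, is_server=False):
client_store = create_c10d_store(
    is_server=False,
    server_addr="10.0.0.1",
    server_port=29500,
    world_size=2,
)

# Either side can now exchange small pieces of metadata through the store.
server_store.set("master_addr", "10.0.0.1")
print(client_store.get("master_addr"))  # b"10.0.0.1"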
def test_address_already_in_use(self):
    with self.assertRaisesRegex(RuntimeError, "^Address already in use$"):
        addr = 'localhost'
        port = common.find_free_port()
        store1 = c10d.TCPStore(addr, port, True)
        store2 = c10d.TCPStore(addr, port, True)
parser = argparse.ArgumentParser(
    description='Simple script to simulate NCCL errors. The script is '
    'supposed to be run on multiple different nodes simultaneously with '
    'appropriate rank and world_size. The script runs an allreduce() on '
    'the rank 0 node and aborts all the other nodes to simulate an error '
    'in NCCL')
parser.add_argument('addr', help='address of the master node to connect to.')
parser.add_argument('port', help='port of the master node to connect to.')
parser.add_argument('rank', help='rank of this node')
parser.add_argument('world_size', help='number of nodes in process group')
args = parser.parse_args()
rank = int(args.rank)
world_size = int(args.world_size)
port = int(args.port)

store = c10d.TCPStore(args.addr, port, world_size, rank == 0)
process_group = c10d.ProcessGroupNCCL(store, rank, world_size)
logging.info('Running first allreduce')
process_group.allreduce(torch.rand(10).cuda(rank)).wait()

if rank == 0:
    logging.info('Running second allreduce only on rank 0')
    work = process_group.allreduce(torch.rand(10).cuda(rank))
    logging.info('Waiting for allreduce to complete...')
    work.wait()
    logging.info('Second allreduce successful: {}'.format(work.is_success()))
else:
    logging.info('Aborting all other ranks.')
    os.abort()
def _worker_fn(world_rank: int, world_size: int, port: int, unused_params: bool):
    device = (
        torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    )
    tcp_store = distrib.TCPStore(  # type: ignore
        "127.0.0.1", port, world_size, world_rank == 0
    )
    distrib.init_process_group(
        "gloo", store=tcp_store, rank=world_rank, world_size=world_size
    )

    config = get_config("habitat_baselines/config/test/ppo_pointnav_test.yaml")
    obs_space = gym.spaces.Dict({
        IntegratedPointGoalGPSAndCompassSensor.cls_uuid: gym.spaces.Box(
            low=np.finfo(np.float32).min,
            high=np.finfo(np.float32).max,
            shape=(2,),
            dtype=np.float32,
        )
    })
    action_space = ActionSpace({"move": EmptySpace()})
    actor_critic = PointNavBaselinePolicy.from_config(config, obs_space, action_space)
    # This use adds some arbitrary parameters that aren't part of the computation
    # graph, so they will mess up DDP if they aren't correctly ignored by it
    if unused_params:
        actor_critic.unused = nn.Linear(64, 64)

    actor_critic.to(device=device)
    ppo_cfg = config.RL.PPO
    agent = DDPPO(
        actor_critic=actor_critic,
        clip_param=ppo_cfg.clip_param,
        ppo_epoch=ppo_cfg.ppo_epoch,
        num_mini_batch=ppo_cfg.num_mini_batch,
        value_loss_coef=ppo_cfg.value_loss_coef,
        entropy_coef=ppo_cfg.entropy_coef,
        lr=ppo_cfg.lr,
        eps=ppo_cfg.eps,
        max_grad_norm=ppo_cfg.max_grad_norm,
        use_normalized_advantage=ppo_cfg.use_normalized_advantage,
    )
    agent.init_distributed()

    rollouts = RolloutStorage(
        ppo_cfg.num_steps,
        2,
        obs_space,
        action_space,
        ppo_cfg.hidden_size,
        num_recurrent_layers=actor_critic.net.num_recurrent_layers,
        is_double_buffered=False,
    )
    rollouts.to(device)

    for k, v in rollouts.buffers["observations"].items():
        rollouts.buffers["observations"][k] = torch.randn_like(v)

    # Add two steps so batching works
    rollouts.advance_rollout()
    rollouts.advance_rollout()

    # Get a single batch
    batch = next(rollouts.recurrent_generator(rollouts.buffers["returns"], 1))

    # Call eval actions through the internal wrapper that is used in
    # agent.update
    value, action_log_probs, dist_entropy, _ = agent._evaluate_actions(
        batch["observations"],
        batch["recurrent_hidden_states"],
        batch["prev_actions"],
        batch["masks"],
        batch["actions"],
    )
    # Backprop on things
    (value.mean() + action_log_probs.mean() + dist_entropy.mean()).backward()

    # Make sure all ranks have very similar parameters
    for param in actor_critic.parameters():
        if param.grad is not None:
            grads = [param.grad.detach().clone() for _ in range(world_size)]
            distrib.all_gather(grads, grads[world_rank])

            for i in range(world_size):
                assert torch.isclose(grads[i], grads[world_rank]).all()
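A worker function like `_worker_fn` above is typically driven by a small launcher that picks a free port and spawns one process per rank. The following is a minimal sketch of such a launcher using `torch.multiprocessing.spawn`; the launcher itself, including the names `run_distributed_test` and `_find_free_port`, is an assumption and not part of the original snippet:

# Hypothetical launcher for a worker function with the signature
# _worker_fn(world_rank, world_size, port, unused_params).
import socket

import torch.multiprocessing as mp


def _find_free_port() -> int:
    # Ask the OS for an ephemeral port, then release it for the store to claim.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("127.0.0.1", 0))
        return s.getsockname()[1]


def run_distributed_test(world_size: int = 2, unused_params: bool = False):
    port = _find_free_port()
    # spawn() calls _worker_fn(rank, *args) once per process.
    mp.spawn(
        _worker_fn,
        args=(world_size, port, unused_params),
        nprocs=world_size,
        join=True,
    )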
def main(argv: List[str]) -> None:
    """Script entry point.

    Parameters
    ----------
    argv: list[str]
        List of CLI arguments.

    Returns
    -------
    None
    """
    # Parse CLI arguments.
    args = parse_args(argv=argv)

    # `args.batch_size` validation.
    lmp.util.validate.raise_if_wrong_ordered(vals=[1, args.batch_size], val_names=['1', 'args.batch_size'])
    # `args.first_ckpt` validation.
    lmp.util.validate.raise_if_wrong_ordered(vals=[-1, args.first_ckpt], val_names=['-1', 'args.first_ckpt'])
    # `args.last_ckpt` validation.
    lmp.util.validate.raise_if_wrong_ordered(vals=[-1, args.last_ckpt], val_names=['-1', 'args.last_ckpt'])
    # `args.n_worker` validation.
    lmp.util.validate.raise_if_wrong_ordered(
        vals=[0, args.n_worker, len(os.sched_getaffinity(0))],
        val_names=['0', 'args.n_worker', 'number of available CPUs'],
    )
    lmp.util.validate.raise_if_wrong_ordered(
        vals=[args.n_worker, args.batch_size],
        val_names=['args.n_worker', 'args.batch_size'],
    )

    # We use TCP to perform RPC. Timeout is set to 5 minutes.
    store = dist.TCPStore(
        is_master=args.rank == HOST_RANK,
        host_name=args.host_name,
        port=args.host_port,
        timeout=timedelta(minutes=5),
        world_size=args.world_size,
    )

    # Use NCCL backend to perform CUDA collectives.
    dist.init_process_group(
        backend=dist.Backend.NCCL,
        store=store,
        rank=args.rank,
        timeout=timedelta(minutes=5),
        world_size=args.world_size,
    )

    # Sync arguments.
    dist_args_k = ['host_name', 'host_port', 'local_rank', 'rank', 'world_size']
    for k in args.__dict__.keys():
        if k in dist_args_k:
            continue

        # Host broadcasts arguments.
        if args.rank == HOST_RANK:
            store.set(k, str(args.__dict__[k]))
        # Non-host processes receive the host's arguments.
        else:
            v = store.get(k)
            if isinstance(args.__dict__[k], str):
                args.__dict__[k] = v.decode('utf-8')
            else:
                args.__dict__[k] = type(args.__dict__[k])(v)

    # Set random seed for reproducibility. Note that each process uses a different seed to get a different slice of
    # each batch.
    lmp.util.rand.set_seed(seed=args.seed + args.rank)

    # Get model running device.
    device = torch.device('cpu')
    if torch.cuda.is_available():
        device = torch.device(f'cuda:{args.local_rank}')

    # Load pre-trained model configuration.
    model_cfg = lmp.util.cfg.load(exp_name=args.exp_name)

    # Load pre-trained tokenizer instance.
    tknzr = lmp.util.tknzr.load(exp_name=model_cfg.tknzr_exp_name)

    # Get dataset instance and convert samples to tensor.
    if args.is_dset_in_memory:
        dset: torch.utils.data.Dataset = lmp.util.dset.FastTensorDset(
            dset=lmp.util.dset.load(**args.__dict__),
            max_seq_len=model_cfg.max_seq_len,
            tknzr=tknzr,
        )
    else:
        dset = lmp.util.dset.SlowTensorDset(
            dset=lmp.util.dset.load(**args.__dict__),
            max_seq_len=model_cfg.max_seq_len,
            tknzr=tknzr,
        )

    dset_size = len(dset)

    # Mini-batch sampler. Each process will get batches exclusive to itself.
    dist_sampler = torch.utils.data.distributed.DistributedSampler(
        num_replicas=args.world_size,
        rank=args.rank,
        dataset=dset,
        shuffle=False,
    )

    # Mini-batch distributed data loader. Only when `args.n_worker > 0` do we set `persistent_workers = True`.
    # We set `pin_memory = True` to speed things up (which only saves a few seconds).
    data_loader = torch.utils.data.DataLoader(
        batch_size=args.batch_size // args.world_size,
        dataset=dset,
        num_workers=args.n_worker,
        persistent_workers=bool(args.n_worker != 0),
        pin_memory=True,
        sampler=dist_sampler,
    )

    # Get tensorboard logger instance. Only the main process needs to log performance.
    if args.rank == HOST_RANK:
        writer = lmp.util.log.get_tb_logger(exp_name=args.exp_name)
    else:
        writer = None

    # Evaluate checkpoints within the requested range.
    for ckpt in lmp.util.model.list_ckpts(exp_name=args.exp_name, first_ckpt=args.first_ckpt, last_ckpt=args.last_ckpt):
        # Load pre-trained model instance.
        model = lmp.util.model.load(ckpt=ckpt, exp_name=args.exp_name)

        # Set model to evaluation mode. This turns off dropout layers in the model.
        model = model.eval()

        # Move model to running device.
        model = model.to(device)

        # Create DDP model.
        ddp_model = torch.nn.parallel.DistributedDataParallel(model)

        # Processes can receive unevenly distributed numbers of batches. Thus one must use `ddp_model.join()` to
        # avoid deadlock.
        with ddp_model.join():
            # Record average perplexity.
            avg_ppl = 0.0
            for batch_tkids in tqdm(data_loader):
                # Encode text into token ids. We convert token ids into tensor and move to the same running device
                # as the model.
                batch_tkids = batch_tkids.to(device)

                # Format batch token ids to satisfy language model training format.
                batch_cur_tkids = batch_tkids[..., :-1]
                batch_next_tkids = batch_tkids[..., 1:]

                # Loop over token ids to get next token id prediction probability distribution.
                batch_prev_states = None
                batch_tkids_pd = []
                for i in range(batch_cur_tkids.size(1)):
                    batch_next_tkids_pd, batch_prev_states = model.pred(
                        batch_cur_tkids=batch_cur_tkids[:, i],
                        batch_prev_states=batch_prev_states,
                    )

                    # Collect prediction probability distribution.
                    batch_tkids_pd.append(batch_next_tkids_pd)

                # Calculate perplexity.
                batch_ppl = lmp.util.metric.ppl(
                    batch_tkids=batch_next_tkids,
                    batch_tkids_pd=torch.stack(batch_tkids_pd, dim=1),
                )

                # Sum `batch_ppl` from each process.
                dist.all_reduce(batch_ppl, op=dist.ReduceOp.SUM)

                # Accumulate average perplexity.
                avg_ppl += (batch_ppl / dset_size).sum().item()

        # Log average perplexity on dataset to CLI and tensorboard. Only the main process needs to log performance.
        if args.rank == HOST_RANK:
            writer.add_scalar(f'ppl/{args.dset_name}/{args.ver}', avg_ppl, ckpt)
            print(f'checkpoint: {ckpt}, avg ppl: {avg_ppl}')

    # Free memory. This is only needed for unit tests.
    del args
    del avg_ppl
    del batch_cur_tkids
    del batch_next_tkids
    del batch_next_tkids_pd
    del batch_ppl
    del batch_prev_states
    del batch_tkids
    del batch_tkids_pd
    del ckpt
    del data_loader
    del device
    del dset
    del dset_size
    del model
    del model_cfg
    del tknzr
    del writer
    torch.cuda.empty_cache()
    gc.collect()