Example 1
    def test_multitenancy(self):
        addr = DEFAULT_HOSTNAME
        port = common.find_free_port()

        # Use noqa to silence flake8.
        # Need to store in an unused variable here to ensure the first
        # object is not destroyed before the second object is created.
        store1 = dist.TCPStore(addr, port, 1, True, multi_tenant=True)  # type: ignore[call-arg] # noqa: F841
        store2 = dist.TCPStore(addr, port, 1, True, multi_tenant=True)  # type: ignore[call-arg] # noqa: F841
Example 2
    def test_address_already_in_use(self):
        with self.assertRaisesRegex(RuntimeError, "^Address already in use$"):
            addr = 'localhost'
            port = common.find_free_port()

            # Use noqa to silence flake8.
            # Need to store in an unused variable here to ensure the first
            # object is not destroyed before the second object is created.
            store1 = c10d.TCPStore(addr, port, True)  # noqa: F841
            store2 = c10d.TCPStore(addr, port, True)  # noqa: F841
Example 3
    def test_address_already_in_use(self):
        err_msg_reg = "^The server socket has failed to listen on any local "
        with self.assertRaisesRegex(RuntimeError, err_msg_reg):
            addr = DEFAULT_HOSTNAME
            port = common.find_free_port()

            # Use noqa to silence flake8.
            # Need to store in an unused variable here to ensure the first
            # object is not destroyed before the second object is created.
            store1 = dist.TCPStore(addr, port, 1, True)  # noqa: F841
            store2 = dist.TCPStore(addr, port, 1, True)  # noqa: F841
Example 4
    def test_address_already_in_use(self):
        if sys.platform == "win32":
            err_msg_reg = "Only one usage of each socket address*"
        else:
            err_msg_reg = "^Address already in use$"
        with self.assertRaisesRegex(RuntimeError, err_msg_reg):
            addr = DEFAULT_HOSTNAME
            port = common.find_free_port()

            # Use noqa to silence flake8.
            # Need to store in an unused variable here to ensure the first
            # object is not destroyed before the second object is created.
            store1 = dist.TCPStore(addr, port, 1, True)  # noqa: F841
            store2 = dist.TCPStore(addr, port, 1, True)  # noqa: F841
Example 5
def init_distrib(world_rank,
                 world_size,
                 backend: str = "nccl",
                 port_offset: int = 0):
    assert torch.distributed.is_available(
    ), "torch.distributed must be available"

    if "GLOO_SOCKET_IFNAME" not in os.environ:
        os.environ["GLOO_SOCKET_IFNAME"] = get_ifname()

    if "NCCL_SOCKET_IFNAME" not in os.environ:
        os.environ["NCCL_SOCKET_IFNAME"] = get_ifname()

    master_port = int(
        os.environ.get(
            "MASTER_PORT",
            DEFAULT_PORT +
            int(SLURM_JOBID if SLURM_JOBID is not None else 0) % 127 +
            port_offset,
        ))
    master_addr = os.environ.get("MASTER_ADDR", DEFAULT_MASTER_ADDR)

    tcp_store = distrib.TCPStore(master_addr, master_port, world_size,
                                 world_rank == 0)
    distrib.init_process_group(backend,
                               store=tcp_store,
                               rank=world_rank,
                               world_size=world_size)

    return tcp_store
Example 6
    def test_dist_broadcast_coalesced(self):
        # Set up process group.
        store = c10d.TCPStore('localhost', self.port, self.is_master)
        options = c10d.ProcessGroupGloo.Options()
        options.devices = [
            c10d.ProcessGroupGloo.create_tcp_device(interface="lo")
        ]
        process_group = c10d.ProcessGroupGloo(store, self.rank,
                                              self.world_size, options)

        device = torch.device('cuda')

        target = torch.arange(10, dtype=torch.float64, device=device).chunk(5)

        if self.is_master:
            # All processes should have these tensors in the end.
            tensors = target
        else:
            # Non-master processes start with empty tensors and should be
            # filled with the tensors from the master.
            tensors = torch.zeros(10, device=device).chunk(5)

        c10d._dist_broadcast_coalesced(tensors,
                                       buffer_size=10,
                                       process_group=process_group)

        if not self.is_master:
            self.assertEqual(tensors, target)
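`_dist_broadcast_coalesced` is a private test helper. The same end state (every rank holding the master's tensors) can be reached with the public `broadcast` collective, one tensor at a time, at the cost of the coalescing optimization. A minimal sketch, assuming a process group has already been constructed as above:

import torch.distributed as dist

def broadcast_tensors(tensors, process_group, src=0):
    # Broadcast each tensor from the source rank; the other ranks are overwritten in place.
    for tensor in tensors:
        dist.broadcast(tensor, src=src, group=process_group)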
Example 7
    def test_sync_params_no_buffers(self):
        # Set up process group.
        store = c10d.TCPStore('localhost', self.port, self.is_master)
        options = c10d.ProcessGroupGloo.Options()
        options.devices = [
            c10d.ProcessGroupGloo.create_tcp_device(interface="lo")
        ]
        process_group = c10d.ProcessGroupGloo(store, self.rank,
                                              self.world_size, options)

        # Use all available devices on every process here (data is small, so should be fine).
        devices = gpus_for_rank(self.world_size)[self.rank]
        target = torch.arange(10, dtype=torch.float64,
                              device='cuda:0').chunk(5)
        parameter_data = [target]
        parameter_data += [
            torch.zeros(10, device=torch.device('cuda', d)).chunk(5)
            for d in devices[1:]
        ]
        buffer_data = [[]] * len(parameter_data)

        c10d._sync_params(process_group,
                          parameter_data=parameter_data,
                          buffer_data=buffer_data,
                          devices=devices,
                          broadcast_bucket_size=10,
                          broadcast_buffers=False)

        for device_data in parameter_data:
            for i, parameter in enumerate(device_data):
                self.assertEqual(parameter, target[i])
Example 8
    def test_fp16(self):
        store = c10d.TCPStore('localhost', self.port, self.rank == 0)
        process_group = c10d.ProcessGroupNCCL(store, self.rank,
                                              self.world_size)

        gpus = gpus_for_rank(self.world_size)[self.rank]
        model = nn.Linear(1, 1, bias=False).cuda(gpus[0]).half()
        nn.init.constant_(model.weight, 1)
        ddp_model = DistributedDataParallel(
            model,
            device_ids=[gpus[0]],
            process_group=process_group,
            bucket_cap_mb=1,
        )

        # Input 2**15, so that the gradients will overflow with a
        # world_size of 2, unless we normalize the gradient by the
        # world_size before the reduction
        input = torch.Tensor([[2**15]]).cuda(gpus[0]).half()

        # Step model
        ddp_model.train()
        output = ddp_model(input)
        loss = output.sum()
        loss.backward()

        self.assertFalse(
            any(torch.isinf(p.grad).any() for p in ddp_model.parameters()))
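The reasoning in this test's comment is worth spelling out: with a world size of 2, summing the per-rank fp16 gradients of 2**15 exceeds the fp16 maximum of 65504, while dividing by the world size before the reduction keeps the result finite. A minimal CPU-only check of that arithmetic:

import torch

# fp16 tops out at 65504, so summing two gradients of 2**15 overflows to inf,
# while averaging them first (divide by world_size=2) stays in range.
g = torch.tensor([2.0 ** 15], dtype=torch.float16)
print(g + g)          # tensor([inf], dtype=torch.float16)
print(g / 2 + g / 2)  # tensor([32768.], dtype=torch.float16)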
Example 9
 def test_nccl_backend(self):
     store = c10d.TCPStore('localhost', self.port, self.is_master)
     process_group = c10d.ProcessGroupNCCL(store, self.rank,
                                           self.world_size)
     gpus = gpus_for_rank(self.world_size)[self.rank]
     self._test_ddp_with_process_group(process_group, gpus)
     self._test_ddp_with_process_group(
         process_group,
         list(map(lambda i: torch.device('cuda:' + str(i)), gpus)))
Example 10
def create_tcp_store(addr="localhost", world_size=1, is_master=True, timeout=timedelta(minutes=5),
                     wait_for_workers=True, jit_class=False):
    """
    Creates a TCP store. Retries if the chosen port is already in use.
    """
    port = find_free_port()
    if jit_class:
        timeout_millisecond = int(timeout / timedelta(milliseconds=1))
        return torch.classes.dist_c10d.TCPStore(addr, port, world_size, is_master, timeout_millisecond)
    else:
        return c10d.TCPStore(addr, port, world_size, is_master, wait_for_workers=wait_for_workers)
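The `timeout / timedelta(milliseconds=1)` division above is how a `datetime.timedelta` is turned into the integer millisecond count that the TorchScript class variant expects, since it cannot take a `timedelta` directly. A standard-library-only sketch of that conversion:

from datetime import timedelta

# Dividing by a one-millisecond timedelta yields the length in milliseconds.
timeout = timedelta(minutes=5)
timeout_ms = int(timeout / timedelta(milliseconds=1))
assert timeout_ms == 5 * 60 * 1000  # 300000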
Example 11
 def _create_client(self, index, addr, port, world_size):
     client_store = dist.TCPStore(addr,
                                  port,
                                  world_size,
                                  timeout=timedelta(seconds=10))
     self.assertEqual("value".encode(), client_store.get("key"))
     client_store.set(f"new_key{index}", f"new_value{index}")
     self.assertEqual(
         f"next_value{index}".encode(),
         client_store.compare_set(f"new_key{index}", f"new_value{index}",
                                  f"next_value{index}"))
Example 12
 def _create_client(self, index, addr, port, world_size, messages):
     try:
         client_store = dist.TCPStore(addr, port, world_size, timeout=timedelta(seconds=10))
         self.assertEqual("value".encode(), client_store.get("key"))
         client_store.set(f"new_key{index}", f"new_value{index}")
         self.assertEqual(f"next_value{index}".encode(),
                          client_store.compare_set(f"new_key{index}", f"new_value{index}", f"next_value{index}"))
     except Exception:
         messages.put('Caught exception: \n{}exiting process with exit code: {}'
                      .format(traceback.format_exc(), MultiProcessTestCase.TEST_ERROR_EXIT_CODE))
         sys.exit(MultiProcessTestCase.TEST_ERROR_EXIT_CODE)
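Both client helpers above assume that a server-side store is already listening on `addr:port` with the key `"key"` pre-populated. A hedged sketch of that server-side counterpart (the address, port, and world size are placeholders; `wait_for_workers=False` keeps the constructor from blocking until all clients have joined):

from datetime import timedelta
import torch.distributed as dist

addr, port, world_size = "127.0.0.1", 29500, 2  # placeholder rendezvous settings

# is_master=True makes this instance bind the port and serve the other ranks.
server_store = dist.TCPStore(addr, port, world_size, True,
                             timeout=timedelta(seconds=10),
                             wait_for_workers=False)
# Seed the key the client helpers read right after connecting.
server_store.set("key", "value")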
Example 13
 def _nccl_init(self, nccl_addr, nccl_ip, nccl_port):
     self.nccl_ip, self.nccl_addr, self.nccl_port = nccl_ip, nccl_addr, nccl_port
     print('Rank {} calling init_process_group. Addr: {}'.format(self.rank, nccl_addr))
     # from https://github.com/pytorch/pytorch/blob/master/test/simulate_nccl_errors.py
     store = dist.TCPStore(self.nccl_ip, self.nccl_port, self.nb_learners, self.rank == 0)
     process_group = dist.ProcessGroupNCCL(store, self.rank, self.nb_learners)
     print('Rank {} initialized process group.'.format(self.rank))
     process_group.barrier()
     print('Rank {} process group barrier finished.'.format(self.rank))
     self.process_group = process_group
     # set optimizer process_group
     self.optimizer.set_process_group(self.process_group)
Example 14
 def init_process(self, rank, args, kwargs):
     dist.init_process_group(self.backend,
                             rank=rank,
                             world_size=self.world_size)
     store = dist.TCPStore("127.0.0.1", 1234, 2, rank == 0,
                           timedelta(seconds=30))
     trainer = self.cls(*args, **kwargs)
     store.set("shared", str(json.dumps(RecDict(trainer.state.shared))))
     trainer._store = store
     trainer._dist = dist
     trainer._rank = rank
     trainer.run()
Example 15
def create_tcp_store(addr):
    """
    Creates a TCP store. Retries if the chosen port is already in use.
    """
    while True:
        try:
            port = common.find_free_port()
            return c10d.TCPStore(addr, port, True)
        except RuntimeError as error:
            if str(error) == "Address already in use":
                continue
            raise
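The retry above matches the exact error text raised by older PyTorch builds; Example 3 shows that newer builds report "The server socket has failed to listen ..." instead. A version-tolerant variant might match more loosely, as in the hedged sketch below (the substring checks are an assumption, not an official API, and the free-port helper is re-implemented with the standard library to keep the sketch self-contained):

import socket
from datetime import timedelta
import torch.distributed as dist

def _find_free_port():
    # Ask the OS for a free port; note another process can still grab it between
    # closing this socket and binding the TCPStore.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        sock.bind(("localhost", 0))
        return sock.getsockname()[1]

def create_tcp_store_with_retry(addr="localhost", max_attempts=10):
    for _ in range(max_attempts):
        port = _find_free_port()
        try:
            return dist.TCPStore(addr, port, 1, True,
                                 timeout=timedelta(seconds=30))
        except RuntimeError as error:
            # Older builds say "Address already in use"; newer ones mention the
            # failing server socket.  The loose match here is an assumption.
            if "Address already in use" in str(error) or "failed to listen" in str(error):
                continue
            raise
    raise RuntimeError("could not bind a TCPStore after several attempts")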
Example 16
 def test_gloo_backend(self):
     store = c10d.TCPStore('localhost', self.port, self.is_master)
     options = c10d.ProcessGroupGloo.Options()
     options.devices = [
         c10d.ProcessGroupGloo.create_tcp_device(interface="lo")
     ]
     process_group = c10d.ProcessGroupGloo(store, self.rank,
                                           self.world_size, options)
     gpus = gpus_for_rank(self.world_size)[self.rank]
     self._test_ddp_with_process_group(process_group, gpus)
     self._test_ddp_with_process_group(
         process_group,
         list(map(lambda i: torch.device('cuda:' + str(i)), gpus)))
Example 17
def init_distrib_slurm(
    backend: str = "nccl", ) -> Tuple[int, torch.distributed.TCPStore]:
    r"""Initializes torch.distributed by parsing environment variables set
        by SLURM when `srun` is used or by parsing environment variables set
        by torch.distributed.launch

    :param backend: Which torch.distributed backend to use

    :returns: Tuple of the local_rank (aka which GPU to use for this process)
        and the TCPStore used for the rendezvous
    """
    assert (torch.distributed.is_available()
            ), "torch.distributed must be available"

    if "GLOO_SOCKET_IFNAME" not in os.environ:
        os.environ["GLOO_SOCKET_IFNAME"] = get_ifname()

    if "NCCL_SOCKET_IFNAME" not in os.environ:
        os.environ["NCCL_SOCKET_IFNAME"] = get_ifname()

    master_port = int(os.environ.get("MASTER_PORT", DEFAULT_PORT))
    master_addr = os.environ.get("MASTER_ADDR", DEFAULT_MASTER_ADDR)

    # Check to see if we should parse from torch.distributed.launch
    if os.environ.get("LOCAL_RANK", None) is not None:
        local_rank = int(os.environ["LOCAL_RANK"])
        world_rank = int(os.environ["RANK"])
        world_size = int(os.environ["WORLD_SIZE"])
    # Else parse from SLURM if using SLURM
    elif os.environ.get("SLURM_JOBID", None) is not None:
        local_rank = int(os.environ["SLURM_LOCALID"])
        world_rank = int(os.environ["SLURM_PROCID"])
        world_size = int(os.environ["SLURM_NTASKS"])
    # Otherwise set up just 1 process; this is nice for testing
    else:
        local_rank = 0
        world_rank = 0
        world_size = 1
    # Default port to initialize the TCP store on
    master_port += 3
    # Default address of world rank 0
    master_addr = "127.0.0.3"
    tcp_store = distrib.TCPStore(master_addr, master_port, world_size,
                                 world_rank == 0)
    distrib.init_process_group(backend,
                               store=tcp_store,
                               rank=world_rank,
                               world_size=world_size)

    return local_rank, tcp_store
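A typical call site for a helper like this, sketched under the assumption that `init_distrib_slurm` is importable as defined above and that the gloo backend is used so the sketch also runs without GPUs:

import torch
import torch.distributed as distrib

local_rank, tcp_store = init_distrib_slurm(backend="gloo")

# Pin each process to the GPU matching its local rank when CUDA is available.
if torch.cuda.is_available():
    torch.cuda.set_device(local_rank)

# The rendezvous store can double as a small shared key-value channel.
tcp_store.set("run_id", "example")
print(distrib.get_rank(), distrib.get_world_size(), tcp_store.get("run_id"))

distrib.barrier()
distrib.destroy_process_group()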
Example 18
 def init_process(self, rank, args, kwargs, mode):
     dist.init_process_group(self.backend,
                             rank=rank,
                             world_size=self.world_size)
     store = dist.TCPStore(os.environ['STORE_ADDR'],
                           int(os.environ['STORE_PORT']), len(self.world),
                           rank == 0, timedelta(seconds=30))
     self.__trainer._dist = dist
     self.__trainer._rank = rank
     self.__trainer._mode = mode
     self.__trainer._world_size = len(self.world)
     store.set("test",
               str(json.dumps(RecDict(self.__trainer._metrics.test))))
     self.__trainer._store = store
     self.__trainer.run(*args, **kwargs)
Example 19
    def test_sync_params_with_buffers(self):
        # Set up process group.
        store = c10d.TCPStore('localhost', self.port, self.is_master)
        options = c10d.ProcessGroupGloo.Options()
        options.devices = [
            c10d.ProcessGroupGloo.create_tcp_device(interface="lo")
        ]
        process_group = c10d.ProcessGroupGloo(store, self.rank,
                                              self.world_size, options)

        devices = gpus_for_rank(self.world_size)[self.rank]
        target = torch.arange(10, dtype=torch.float64,
                              device='cuda:0').chunk(5)
        parameter_data = [target]
        parameter_data += [
            torch.zeros(10, device=torch.device('cuda', d)).chunk(5)
            for d in devices[1:]
        ]

        # sync_params should do a dist_broadcast for buffers, so we only populate the master buffers and
        # then check that other processes' tensors end up matching.

        if self.is_master:
            buffer_data = [target]
            buffer_data += [
                torch.zeros(10, device=torch.device('cuda', d)).chunk(5)
                for d in devices[1:]
            ]
        else:
            buffer_data = [
                torch.zeros(10, device=torch.device('cuda', d)).chunk(5)
                for d in devices
            ]

        c10d._sync_params(process_group,
                          parameter_data=parameter_data,
                          buffer_data=buffer_data,
                          devices=devices,
                          broadcast_bucket_size=10,
                          broadcast_buffers=True)

        for device_data in parameter_data:
            for i, parameter in enumerate(device_data):
                self.assertEqual(parameter, target[i])

        for device_data in buffer_data:
            for i, buffer in enumerate(device_data):
                self.assertEqual(buffer, target[i])
Example 20
def init_distrib_slurm(
    backend: str = "nccl",
) -> Tuple[int, torch.distributed.TCPStore]:  # type: ignore
    r"""Initializes torch.distributed by parsing environment variables set
        by SLURM when ``srun`` is used or by parsing environment variables set
        by torch.distributed.launch

    :param backend: Which torch.distributed backend to use

    :returns: Tuple of the local_rank (aka which GPU to use for this process)
        and the TCPStore used for the rendezvous
    """
    assert (torch.distributed.is_available()
            ), "torch.distributed must be available"

    if "GLOO_SOCKET_IFNAME" not in os.environ:
        os.environ["GLOO_SOCKET_IFNAME"] = get_ifname()

    if "NCCL_SOCKET_IFNAME" not in os.environ:
        os.environ["NCCL_SOCKET_IFNAME"] = get_ifname()

    local_rank, world_rank, world_size = get_distrib_size()

    master_addr = os.environ.get("MASTER_ADDR", DEFAULT_MASTER_ADDR)
    master_port = int(os.environ.get("MASTER_PORT", DEFAULT_PORT))
    if SLURM_JOBID is not None:
        master_port += int(SLURM_JOBID) % int(
            os.environ.get("MASTER_PORT_RANGE", DEFAULT_PORT_RANGE))
    if MULTI_PROC_OFFSET is not None:
        master_port += int(MULTI_PROC_OFFSET)

    tcp_store = distrib.TCPStore(  # type: ignore
        master_addr, master_port, world_size, world_rank == 0)
    distrib.init_process_group(backend,
                               store=tcp_store,
                               rank=world_rank,
                               world_size=world_size)

    return local_rank, tcp_store
Example 21
def init_distrib_slurm(backend="nccl"):
    if "GLOO_SOCKET_IFNAME" not in os.environ:
        os.environ["GLOO_SOCKET_IFNAME"] = get_ifname()

    if "NCCL_SOCKET_IFNAME" not in os.environ:
        os.environ["NCCL_SOCKET_IFNAME"] = get_ifname()

    master_port = int(os.environ.get("MASTER_PORT", 8738))
    master_addr = os.environ.get("MASTER_ADDR", "127.0.0.1")
    local_rank = int(
        os.environ.get("LOCAL_RANK", os.environ.get("SLURM_LOCALID", 0)))
    world_rank = int(os.environ.get("RANK", os.environ.get("SLURM_PROCID", 0)))
    world_size = int(
        os.environ.get("WORLD_SIZE", os.environ.get("SLURM_NTASKS", 1)))

    tcp_store = distrib.TCPStore(master_addr, master_port, world_size,
                                 world_rank == 0)
    distrib.init_process_group(backend,
                               store=tcp_store,
                               rank=world_rank,
                               world_size=world_size)

    return local_rank, tcp_store
Example 22
 def test_gloo_backend(self):
     store = c10d.TCPStore('localhost', self.port, self.is_master)
     options = c10d.ProcessGroupGloo.Options()
     options.devices = [c10d.ProcessGroupGloo.create_tcp_device(interface="lo")]
     process_group = c10d.ProcessGroupGloo(store, self.rank, self.world_size, options)
     self._test_ddp_with_process_group(process_group)
Example 23
 def test_nccl_backend(self):
     store = c10d.TCPStore('localhost', self.port, self.is_master)
     process_group = c10d.ProcessGroupNCCL(store, self.rank, self.world_size)
     self._test_ddp_with_process_group(process_group)
Example 24
 def setUp(self):
     addr = 'localhost'
     port = common.find_free_port()
     self.tcpstore = c10d.TCPStore(addr, port, True)
     self.prefix = "test_prefix"
     self.tcpstore.set_timeout(timedelta(seconds=300))
Example 25
 def _create_store(self):
     addr = 'localhost'
     port = common.find_free_port()
     store = c10d.TCPStore(addr, port, True)
     store.set_timeout(timedelta(seconds=300))
     return store
Example 26
def create_c10d_store(
    is_server: bool,
    server_addr: str,
    server_port: int = -1,
    world_size: int = 1,
    timeout: float = (60 * 10),  # 10 min
    wait_for_workers: bool = True,
    retries=3,
):
    if server_port == -1 and world_size > 1:
        raise ValueError(
            f"server_port must be specified when world_size > 1, got server_port={server_port}, world_size={world_size}"
        )

    if server_port != -1:
        log.info(f"server_port: {server_port} specified, ignoring retries")

    # only retry when server_port is NOT static
    attempt = retries if server_port == -1 else 1
    while True:
        if server_port != -1:
            port = server_port
        else:
            port = get_free_port()

        log.info(
            f"Creating c10d store on {server_addr}:{port}\n"
            f"  world_size  : {world_size}\n"
            f"  is_server   : {is_server}\n"
            f"  timeout(sec): {timeout}\n"
        )

        try:
            store = dist.TCPStore(
                host_name=server_addr,
                port=port,
                world_size=world_size,
                is_master=is_server,
                timeout=datetime.timedelta(seconds=timeout),
                wait_for_workers=wait_for_workers,
            )
            # skips full rank check when we don't have to wait for all workers
            if wait_for_workers:
                _check_full_rank(store, world_size)
            log.info("Successfully created c10d store")
            return store
        except RuntimeError as e:
            # this is brittle, but the underlying exception type is not properly pybinded
            # so we parse the error msg for now, interestingly this is how torch itself
            # detects timeouts and port conflicts in their own unittests
            # see - caffe2/torch/testing/_internal/common_utils.py
            # TODO properly map the exceptions in pybind (c10d/init.cpp)
            if str(e) == _ADDRESS_IN_USE:  # this will only happen on the server
                if attempt < retries:
                    log.warning(
                        f"port: {port} already in use, attempt: [{attempt}/{retries}]"
                    )
                    attempt += 1
                else:
                    raise RuntimeError(
                        f"on {server_addr}, port: {port} already in use"
                    ) from e
            else:
                raise
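As a usage sketch (hedged: `create_c10d_store` is the helper defined above, and the host, port, and world size are placeholders), rank 0 creates the store as the server while the remaining ranks connect to the same address as clients; the resulting store can back a process group directly:

import torch.distributed as dist

WORLD_SIZE = 2
SERVER_ADDR = "127.0.0.1"  # placeholder rendezvous host
SERVER_PORT = 29500        # placeholder static port, so no retries are attempted

def rendezvous(rank: int):
    store = create_c10d_store(
        is_server=(rank == 0),
        server_addr=SERVER_ADDR,
        server_port=SERVER_PORT,
        world_size=WORLD_SIZE,
        timeout=60,
    )
    dist.init_process_group("gloo", store=store, rank=rank, world_size=WORLD_SIZE)
    return store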
Example 27
 def test_address_already_in_use(self):
     with self.assertRaisesRegex(RuntimeError, "^Address already in use$"):
         addr = 'localhost'
         port = common.find_free_port()
         store1 = c10d.TCPStore(addr, port, True)
         store2 = c10d.TCPStore(addr, port, True)
Example 28
    parser = argparse.ArgumentParser(
        description='Simple script to simulate NCCL errors. The script is '
        'supposed to be run on multiple different nodes simultaneously with '
        'appropriate rank and world_size. The script runs an allreduce() on '
        'the rank 0 node and aborts all the other nodes to simulate an error '
        'in NCCL')
    parser.add_argument('addr',
                        help='address of the master node to connect to.')
    parser.add_argument('port', help='port of the master node to connect to.')
    parser.add_argument('rank', help='rank of this node')
    parser.add_argument('world_size', help='number of nodes in process group')
    args = parser.parse_args()
    rank = int(args.rank)
    world_size = int(args.world_size)
    port = int(args.port)

    store = c10d.TCPStore(args.addr, port, world_size, rank == 0)
    process_group = c10d.ProcessGroupNCCL(store, rank, world_size)
    logging.info('Running first allreduce')
    process_group.allreduce(torch.rand(10).cuda(rank)).wait()
    if rank == 0:
        logging.info('Running second allreduce only on rank 0')
        work = process_group.allreduce(torch.rand(10).cuda(rank))
        logging.info('Waiting for allreduce to complete...')
        work.wait()
        logging.info('Second allreduce successful: {}'.format(
            work.is_success()))
    else:
        logging.info('Aborting all other ranks.')
        os.abort()
Example 29
def _worker_fn(world_rank: int, world_size: int, port: int,
               unused_params: bool):
    device = (torch.device("cuda")
              if torch.cuda.is_available() else torch.device("cpu"))
    tcp_store = distrib.TCPStore(  # type: ignore
        "127.0.0.1", port, world_size, world_rank == 0)
    distrib.init_process_group("gloo",
                               store=tcp_store,
                               rank=world_rank,
                               world_size=world_size)

    config = get_config("habitat_baselines/config/test/ppo_pointnav_test.yaml")
    obs_space = gym.spaces.Dict({
        IntegratedPointGoalGPSAndCompassSensor.cls_uuid:
        gym.spaces.Box(
            low=np.finfo(np.float32).min,
            high=np.finfo(np.float32).max,
            shape=(2, ),
            dtype=np.float32,
        )
    })
    action_space = ActionSpace({"move": EmptySpace()})
    actor_critic = PointNavBaselinePolicy.from_config(config, obs_space,
                                                      action_space)
    # This adds some arbitrary parameters that aren't part of the computation
    # graph, so they will mess up DDP if they aren't correctly ignored by it
    if unused_params:
        actor_critic.unused = nn.Linear(64, 64)

    actor_critic.to(device=device)
    ppo_cfg = config.RL.PPO
    agent = DDPPO(
        actor_critic=actor_critic,
        clip_param=ppo_cfg.clip_param,
        ppo_epoch=ppo_cfg.ppo_epoch,
        num_mini_batch=ppo_cfg.num_mini_batch,
        value_loss_coef=ppo_cfg.value_loss_coef,
        entropy_coef=ppo_cfg.entropy_coef,
        lr=ppo_cfg.lr,
        eps=ppo_cfg.eps,
        max_grad_norm=ppo_cfg.max_grad_norm,
        use_normalized_advantage=ppo_cfg.use_normalized_advantage,
    )
    agent.init_distributed()
    rollouts = RolloutStorage(
        ppo_cfg.num_steps,
        2,
        obs_space,
        action_space,
        ppo_cfg.hidden_size,
        num_recurrent_layers=actor_critic.net.num_recurrent_layers,
        is_double_buffered=False,
    )
    rollouts.to(device)

    for k, v in rollouts.buffers["observations"].items():
        rollouts.buffers["observations"][k] = torch.randn_like(v)

    # Add two steps so batching works
    rollouts.advance_rollout()
    rollouts.advance_rollout()

    # Get a single batch
    batch = next(rollouts.recurrent_generator(rollouts.buffers["returns"], 1))

    # Call eval actions through the internal wrapper that is used in
    # agent.update
    value, action_log_probs, dist_entropy, _ = agent._evaluate_actions(
        batch["observations"],
        batch["recurrent_hidden_states"],
        batch["prev_actions"],
        batch["masks"],
        batch["actions"],
    )
    # Backprop on things
    (value.mean() + action_log_probs.mean() + dist_entropy.mean()).backward()

    # Make sure all ranks have very similar gradients
    for param in actor_critic.parameters():
        if param.grad is not None:
            grads = [param.grad.detach().clone() for _ in range(world_size)]
            distrib.all_gather(grads, grads[world_rank])

            for i in range(world_size):
                assert torch.isclose(grads[i], grads[world_rank]).all()
Example 30
def main(argv: List[str]) -> None:
    """Script entry point.

  Parameters
  ----------
  argv: list[str]
    List of CLI arguments.

  Returns
  -------
  None
  """
    # Parse CLI arguments.
    args = parse_args(argv=argv)

    # `args.batch_size` validation.
    lmp.util.validate.raise_if_wrong_ordered(
        vals=[1, args.batch_size], val_names=['1', 'args.batch_size'])
    # `args.first_ckpt` validation.
    lmp.util.validate.raise_if_wrong_ordered(
        vals=[-1, args.first_ckpt], val_names=['-1', 'args.first_ckpt'])
    # `args.last_ckpt` validation.
    lmp.util.validate.raise_if_wrong_ordered(
        vals=[-1, args.last_ckpt], val_names=['-1', 'args.last_ckpt'])
    # `args.n_worker` validation.
    lmp.util.validate.raise_if_wrong_ordered(
        vals=[0, args.n_worker, len(os.sched_getaffinity(0))],
        val_names=['0', 'args.n_worker', 'number of available CPUs'],
    )
    lmp.util.validate.raise_if_wrong_ordered(
        vals=[args.n_worker, args.batch_size],
        val_names=['args.n_worker', 'args.batch_size'],
    )

    # We use a TCP store for rendezvous and cross-process key-value exchange.  Timeout is set to 5 minutes.
    store = dist.TCPStore(
        is_master=args.rank == HOST_RANK,
        host_name=args.host_name,
        port=args.host_port,
        timeout=timedelta(minutes=5),
        world_size=args.world_size,
    )

    # Use NCCL backend to perform CUDA collectives.
    dist.init_process_group(
        backend=dist.Backend.NCCL,
        store=store,
        rank=args.rank,
        timeout=timedelta(minutes=5),
        world_size=args.world_size,
    )

    # Sync arguments across ranks via the TCP store (see the stand-alone sketch after this example).
    dist_args_k = [
        'host_name', 'host_port', 'local_rank', 'rank', 'world_size'
    ]
    for k in args.__dict__.keys():
        if k in dist_args_k:
            continue

        # Host broadcasts its arguments.
        if args.rank == HOST_RANK:
            store.set(k, str(args.__dict__[k]))
        # Non-host ranks receive the host's arguments.
        else:
            v = store.get(k)
            if isinstance(args.__dict__[k], str):
                args.__dict__[k] = v.decode('utf-8')
            else:
                args.__dict__[k] = type(args.__dict__[k])(v)

    # Set random seed for reproducibility.  Note that each process uses a different seed to get a different slice of each batch.
    lmp.util.rand.set_seed(seed=args.seed + args.rank)

    # Get model running device.
    device = torch.device('cpu')
    if torch.cuda.is_available():
        device = torch.device(f'cuda:{args.local_rank}')

    # Load pre-trained model configuration.
    model_cfg = lmp.util.cfg.load(exp_name=args.exp_name)

    # Load pre-trained tokenizer instance.
    tknzr = lmp.util.tknzr.load(exp_name=model_cfg.tknzr_exp_name)

    # Get dataset instance and convert samples to tensor.
    if args.is_dset_in_memory:
        dset: torch.utils.data.Dataset = lmp.util.dset.FastTensorDset(
            dset=lmp.util.dset.load(**args.__dict__),
            max_seq_len=model_cfg.max_seq_len,
            tknzr=tknzr,
        )
    else:
        dset = lmp.util.dset.SlowTensorDset(
            dset=lmp.util.dset.load(**args.__dict__),
            max_seq_len=model_cfg.max_seq_len,
            tknzr=tknzr,
        )

    dset_size = len(dset)

    # Mini-batch sampler.  Each process will get batches exclusive to itself.
    dist_sampler = torch.utils.data.distributed.DistributedSampler(
        num_replicas=args.world_size,
        rank=args.rank,
        dataset=dset,
        shuffle=False,
    )

    # Mini-batch distributed data loader.  We set `persistent_workers = True` only when `args.n_worker > 0`.  We set
    # `pin_memory = True` to speed up host-to-device copies (which only saves a few seconds).
    data_loader = torch.utils.data.DataLoader(
        batch_size=args.batch_size // args.world_size,
        dataset=dset,
        num_workers=args.n_worker,
        persistent_workers=bool(args.n_worker != 0),
        pin_memory=True,
        sampler=dist_sampler,
    )

    # Get tensorboard logger instance.  Only the main process needs to log performance.
    if args.rank == HOST_RANK:
        writer = lmp.util.log.get_tb_logger(exp_name=args.exp_name)
    else:
        writer = None

    # Evaluate checkpoints within ranges.
    for ckpt in lmp.util.model.list_ckpts(exp_name=args.exp_name,
                                          first_ckpt=args.first_ckpt,
                                          last_ckpt=args.last_ckpt):
        # Load pre-trained model instance.
        model = lmp.util.model.load(ckpt=ckpt, exp_name=args.exp_name)

        # Set model to evaluation mode.  This turns off dropout layers in the model.
        model = model.eval()

        # Move model to running device.
        model = model.to(device)

        # Create DDP model.
        ddp_model = torch.nn.parallel.DistributedDataParallel(model)

        # Processes can receive an uneven number of batches.  Thus one must use `ddp_model.join()` to avoid deadlock.
        with ddp_model.join():
            # Record average perplexity.
            avg_ppl = 0.0
            for batch_tkids in tqdm(data_loader):
                # Move the batch of token ids to the same running device as the model.
                batch_tkids = batch_tkids.to(device)

                # Format batch token ids to satisfy language model training format.
                batch_cur_tkids = batch_tkids[..., :-1]
                batch_next_tkids = batch_tkids[..., 1:]

                # Loop over token ids to get next token id prediction probability distribution.
                batch_prev_states = None
                batch_tkids_pd = []
                for i in range(batch_cur_tkids.size(1)):
                    batch_next_tkids_pd, batch_prev_states = model.pred(
                        batch_cur_tkids=batch_cur_tkids[:, i],
                        batch_prev_states=batch_prev_states,
                    )

                    # Collect prediction probability distribution.
                    batch_tkids_pd.append(batch_next_tkids_pd)

                # Calculate perplexity.
                batch_ppl = lmp.util.metric.ppl(batch_tkids=batch_next_tkids,
                                                batch_tkids_pd=torch.stack(
                                                    batch_tkids_pd, dim=1))

                # Sum `batch_ppl` from each process.
                dist.all_reduce(batch_ppl, op=dist.ReduceOp.SUM)

                # Accumulate average perplexity.
                avg_ppl += (batch_ppl / dset_size).sum().item()

        # Log average perplexity on dataset to CLI and tensorboard.  Only the main process needs to log performance.
        if args.rank == HOST_RANK:
            writer.add_scalar(f'ppl/{args.dset_name}/{args.ver}', avg_ppl,
                              ckpt)
        print(f'checkpoint: {ckpt}, avg ppl: {avg_ppl}')

    # Free memory.  This is only needed for unit tests.
    del args
    del avg_ppl
    del batch_cur_tkids
    del batch_next_tkids
    del batch_next_tkids_pd
    del batch_ppl
    del batch_prev_states
    del batch_tkids
    del batch_tkids_pd
    del ckpt
    del data_loader
    del device
    del dset
    del dset_size
    del model
    del model_cfg
    del tknzr
    del writer
    torch.cuda.empty_cache()
    gc.collect()
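The argument-sync loop in `main()` above (the host rank calls `store.set`, the other ranks call `store.get` and cast the bytes back) uses nothing beyond the plain `TCPStore` API. A stripped-down, self-contained sketch of that pattern (the host name, port, and config values are made up for illustration, and the type cast carries the same caveat as the original, e.g. it does not round-trip booleans):

from datetime import timedelta
import torch.distributed as dist

HOST_RANK = 0

def sync_config(rank: int, world_size: int, cfg: dict) -> dict:
    # Every rank joins the same TCPStore; only the host rank acts as the server.
    store = dist.TCPStore(
        host_name="127.0.0.1",   # placeholder rendezvous host
        port=29500,              # placeholder rendezvous port
        world_size=world_size,
        is_master=(rank == HOST_RANK),
        timeout=timedelta(minutes=5),
    )
    if rank == HOST_RANK:
        for key, value in cfg.items():
            store.set(key, str(value))  # values are stored as bytes
        return cfg
    # Non-host ranks read each value back and cast it to the original type.
    return {key: type(value)(store.get(key).decode("utf-8"))
            for key, value in cfg.items()}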