Exemplo n.º 1
0
def get_node_count():
    """Return the total number of distributed workers, or None if unknown.

    Sources are tried in order; the first one that yields an int wins:

    1. MPI world size: ``OMPI_COMM_WORLD_SIZE`` (Open MPI), else ``PMI_SIZE``.
    2. SLURM: ``SLURM_NTASKS``, else ``SLURM_JOB_NUM_NODES``.
    3. ``num_workers`` of the current PyTorch DataLoader worker. This is only
       consulted when ``torch`` is already present in ``sys.modules``.

    Within each group of environment variables, only the first variable that is
    set is used. If its value does not parse as an int, the next group is tried.

    :return: the count as an int, or None when no source provides it.
    """
    # Use *size* variables, not *rank* variables: a process's rank is not the
    # number of processes in the job.
    for env_names in (('OMPI_COMM_WORLD_SIZE', 'PMI_SIZE'),
                      ('SLURM_NTASKS', 'SLURM_JOB_NUM_NODES')):
        raw_value = next(
            (os.environ[name] for name in env_names if name in os.environ),
            None)
        if raw_value is None:
            continue
        try:
            return int(raw_value)
        except ValueError:
            # malformed value: fall through to the next source
            continue

    # check if we have pyTorch worker info (only if torch was already imported)
    if 'torch' in sys.modules:
        # Best effort: any failure here simply means "count unknown".
        # noinspection PyBroadException
        try:
            from torch.utils.data.dataloader import get_worker_info  # noqa
            worker_info = get_worker_info()
            if worker_info:
                return int(worker_info.num_workers)
        except Exception:
            pass

    return None
Exemplo n.º 2
0
def get_node_id(default=0):
    """Compute an identifier for the current process in a distributed run.

    The base id is resolved in this order:

    1. ``NODE_ID_ENV_VAR.get()``, when it is not None.
    2. ``OMPI_COMM_WORLD_NODE_RANK``, else ``PMI_RANK``.
    3. ``OMPI_COMM_WORLD_RANK``, else ``SLURM_PROCID``, else ``SLURM_NODEID``.
    4. ``default``.

    For steps 2 and 3, only the first variable that is set is considered. If its
    value does not parse as an int, that step yields nothing.

    If ``torch`` is already in ``sys.modules`` and a DataLoader worker info
    object is available, the result is ``10000 * base_id + worker_rank``. This
    presumes worker ranks stay below 10000. A worker id that cannot be converted
    with ``int()`` is replaced by the integer value of the MD5 hex digest of its
    string form.

    :param default: id used when no other source provides one.
    :return: the computed node id.
    """
    explicit_id = NODE_ID_ENV_VAR.get()  # noqa: F405

    def _int_env(*names):
        # Same as int(os.environ.get(a, os.environ.get(b, ...))): the first
        # name that is set wins; a missing or unparsable value gives None.
        raw = next((os.environ[name] for name in names if name in os.environ), None)
        # noinspection PyBroadException
        try:
            return int(raw)
        except Exception:
            return None

    local_rank = _int_env('OMPI_COMM_WORLD_NODE_RANK', 'PMI_RANK')
    world_rank = _int_env('OMPI_COMM_WORLD_RANK', 'SLURM_PROCID', 'SLURM_NODEID')

    if explicit_id is not None:
        node_id = explicit_id
    elif local_rank is not None:
        node_id = local_rank
    elif world_rank is not None:
        node_id = world_rank
    else:
        node_id = default

    worker_rank = None
    # Only look at PyTorch worker info when the application already imported torch.
    if 'torch' in sys.modules:
        # noinspection PyBroadException
        try:
            from torch.utils.data.dataloader import get_worker_info  # noqa
            info = get_worker_info()
            if info:
                raw_id = info.id
                # noinspection PyBroadException
                try:
                    worker_rank = int(raw_id)
                except Exception:
                    # Not int-like: derive a (hopefully unique) integer from its MD5 digest.
                    import hashlib
                    worker_rank = int(hashlib.md5(str(raw_id).encode('utf-8')).hexdigest(), 16)
        except Exception:
            worker_rank = None

    if worker_rank is None:
        return node_id
    # Reserve 10000 consecutive ids per node for its workers.
    return 10000 * node_id + worker_rank
Exemplo n.º 3
0
 def get_data_with_worker_info(self):
     """Return the part of ``self.data`` that the current DataLoader worker should load.

     When ``get_worker_info()`` returns None (not inside a worker), the full
     ``self.data`` is returned unchanged. Inside a worker, rows
     ``[0, self.num_images)`` are split into ``num_workers`` contiguous,
     non-overlapping shards whose sizes differ by at most one. The shard at
     index ``worker_info.id`` is returned via ``self.data.iloc``. ``self.data``
     is presumably a pandas object given ``.iloc``, and ``self.num_images`` is
     presumably ``len(self.data)``; confirm both against the class.

     :raises ValueError: if there are more workers than images.
     """
     worker_info = get_worker_info()
     if worker_info is None:
         return self.data

     num_workers = worker_info.num_workers
     worker_id = worker_info.id
     if num_workers > self.num_images:
         raise ValueError(
             "Cannot have a number of workers larger than the number of images, "
             "it would result in duplicates.")

     # Balanced split: shard k covers [k*N//W, (k+1)*N//W). Because W <= N,
     # every shard is non-empty. A ceil-sized split could leave trailing
     # workers with nothing (e.g. 5 images over 4 workers -> 2, 2, 1, 0).
     start = worker_id * self.num_images // num_workers
     end = (worker_id + 1) * self.num_images // num_workers
     return self.data.iloc[start:end]