Code example #1
    def master_port(self):
        # -----------------------
        # SLURM JOB = PORT number
        # -----------------------
        # this way every process knows what port to use
        try:
            # use the last 4 numbers in the job id as the id
            default_port = os.environ["SLURM_JOB_ID"]
            default_port = default_port[-4:]

            # all ports should be in the 10k+ range
            default_port = int(default_port) + 15000

        except Exception:
            default_port = 12910

        # -----------------------
        # PORT NUMBER = MASTER_PORT
        # -----------------------
        # in case the user passed it in
        try:
            default_port = os.environ["MASTER_PORT"]
        except Exception:
            os.environ["MASTER_PORT"] = str(default_port)

        log.debug(f"MASTER_PORT: {os.environ['MASTER_PORT']}")

        return default_port
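
The port arithmetic above is easy to verify in isolation; a standalone sketch with a made-up job id:

    # illustrative only: the last four digits of a SLURM job id become the port
    job_id = "4312087"                   # hypothetical SLURM_JOB_ID
    port = int(job_id[-4:]) + 15000      # "2087" -> 2087 + 15000
    print(port)                          # -> 17087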
Code example #2
    def __init__(
        self,
        metric_name: str,
        reduce_group: Any = group.WORLD,
        reduce_op: Any = ReduceOp.SUM,
        **kwargs,
    ):
        """
        Args:
            metric_name: the metric name to import and compute from sklearn.metrics
            reduce_group: the process group for DDP reduces (only needed for DDP training).
                Defaults to all processes (world)
            reduce_op: the operation to perform during reduction within DDP (only needed for DDP training).
                Defaults to sum.
            **kwargs: additional keyword arguments (will be forwarded to the metric call)
        """
        super().__init__(name=metric_name,
                         reduce_group=reduce_group,
                         reduce_op=reduce_op)

        self.metric_kwargs = kwargs
        lightning_logger.debug(
            f'Metric {self.__class__.__name__} is using Sklearn as backend, meaning that'
            ' every metric call will cause a GPU synchronization, which may slow down your code'
        )
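
For context, the stored metric_name resolves to a callable in sklearn.metrics; a standalone sketch of that lookup (not necessarily the wrapper's exact forwarding code):

    import sklearn.metrics

    # resolve a metric by name and call it, roughly what the wrapper forwards to
    metric_fn = getattr(sklearn.metrics, "accuracy_score")
    print(metric_fn([0, 1, 1, 0], [0, 1, 0, 0]))  # -> 0.75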
Code example #3
 def master_address(self):
     if "MASTER_ADDR" not in os.environ:
         rank_zero_warn(
             "MASTER_ADDR environment variable is not defined. Set as localhost"
         )
         os.environ["MASTER_ADDR"] = "127.0.0.1"
     log.debug(f"MASTER_ADDR: {os.environ['MASTER_ADDR']}")
     master_address = os.environ.get('MASTER_ADDR')
     return master_address
Code example #4
    def master_port(self):
        if "MASTER_PORT" not in os.environ:
            rank_zero_warn(
                "MASTER_PORT environment variable is not defined. Set as 12910"
            )
            os.environ["MASTER_PORT"] = "12910"
        log.debug(f"MASTER_PORT: {os.environ['MASTER_PORT']}")

        port = os.environ.get('MASTER_PORT')
        return port
Code example #5
    def master_address(self):
        # figure out the root node addr
        try:
            root_node = os.environ["SLURM_NODELIST"].split(" ")[0]
        except Exception:
            root_node = "127.0.0.1"

        root_node = self._resolve_root_node_address(root_node)
        os.environ["MASTER_ADDR"] = root_node
        log.debug(f"MASTER_ADDR: {os.environ['MASTER_ADDR']}")
        return root_node
Code example #6
 def __init__(self) -> None:
     super().__init__()
     # TODO: remove in 1.7
     if hasattr(self, "is_using_lsf") and callable(self.is_using_lsf):
         rank_zero_deprecation(
             f"`{self.__class__.__name__}.is_using_lsf` has been deprecated in v1.6 and will be removed in v1.7."
             " Implement the static method `detect()` instead (do not forget to add the `@staticmethod` decorator)."
         )
     self._main_address = self._get_main_address()
     self._main_port = self._get_main_port()
     log.debug(f"MASTER_ADDR: {self._main_address}")
     log.debug(f"MASTER_PORT: {self._main_port}")
Code example #7
    def master_address(self):
        # figure out the root node addr
        slurm_nodelist = os.environ.get("SLURM_NODELIST")
        if slurm_nodelist:
            root_node = slurm_nodelist.split(" ")[0].split(",")[0]
        else:
            root_node = "127.0.0.1"

        root_node = self._resolve_root_node_address(root_node)
        os.environ["MASTER_ADDR"] = root_node
        log.debug(f"MASTER_ADDR: {os.environ['MASTER_ADDR']}")
        return root_node
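
_resolve_root_node_address is referenced but not shown in these snippets. A minimal sketch of what such a resolver might look like, assuming it only needs to collapse a SLURM bracket expression such as host[12-20] to the first hostname (the real helper may handle more cases):

    import re

    def resolve_root_node_address(nodes: str) -> str:
        # "host[12-20]" -> "host12", "host[3,5]" -> "host3"
        nodes = re.sub(r"\[(.*?)[,-].*\]", r"\1", nodes)
        nodes = re.sub(r"\[(.*?)\]", r"\1", nodes)
        return nodes.split(" ")[0].split(",")[0]

    print(resolve_root_node_address("node[12-20] gpu-node01"))  # -> node12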
Code example #8
    def _get_main_port() -> int:
        """A helper function for accessing the main port.

        Uses the LSF job ID so all ranks can compute the main port.
        """
        # check for user-specified main port
        if "MASTER_PORT" in os.environ:
            log.debug(f"Using externally specified main port: {os.environ['MASTER_PORT']}")
            return int(os.environ["MASTER_PORT"])
        if "LSB_JOBID" in os.environ:
            port = int(os.environ["LSB_JOBID"])
            # all ports should be in the 10k+ range
            port = port % 1000 + 10000
            log.debug(f"calculated LSF main port: {port}")
            return port
        raise ValueError("Could not find job id in environment variable LSB_JOBID")
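
The LSF mapping is plain modular arithmetic; a worked example with a made-up job id:

    # illustrative only: LSB_JOBID -> a port in the 10k+ range
    jobid = 275189                   # hypothetical LSB_JOBID
    port = jobid % 1000 + 10000      # 189 + 10000
    print(port)                      # -> 10189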
Code example #9
def cloud_open(path: pathlike, mode: str, newline: str = None):
    if platform.system() == "Windows":
        log.debug(
            "gfile does not handle newlines correctly on windows so remote files are not"
            " supported falling back to normal local file open.")
        return open(path, mode, newline=newline)
    if not modern_gfile():
        log.debug(
            "tenosrboard.compat gfile does not work on older versions "
            "of tensorboard for remote files, using normal local file open.")
        return open(path, mode, newline=newline)
    try:
        return gfile.GFile(path, mode)
    except NotImplementedError:
        # minimal dependencies are installed and only local files will work
        return open(path, mode, newline=newline)
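
A hedged usage sketch of cloud_open (the file name is a placeholder; on a purely local setup every branch above ends in the built-in open):

    # remote-capable open: gfile when it works, plain open() otherwise
    with cloud_open("metrics.csv", "w", newline="") as fp:
        fp.write("epoch,loss\n")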
Code example #10
    def connect_ddp(self, global_rank: int, world_size: int) -> None:
        """
        Sets up the environment variables necessary for PyTorch distributed communications
        based on the SLURM environment.
        """
        # use slurm job id for the port number
        # guarantees unique ports across jobs from same grid search
        try:
            # use the last 4 numbers in the job id as the id
            default_port = os.environ["SLURM_JOB_ID"]
            default_port = default_port[-4:]

            # all ports should be in the 10k+ range
            default_port = int(default_port) + 15000

        except Exception:
            default_port = 12910

        # if user gave a port number, use that one instead
        try:
            default_port = os.environ["MASTER_PORT"]
        except Exception:
            os.environ["MASTER_PORT"] = str(default_port)
        log.debug(f"MASTER_PORT: {os.environ['MASTER_PORT']}")

        # figure out the root node addr
        try:
            root_node = os.environ["SLURM_NODELIST"].split(" ")[0].split(",")[0]
        except Exception:
            root_node = "127.0.0.1"

        root_node = self.trainer.slurm_connector.resolve_root_node_address(
            root_node)
        os.environ["MASTER_ADDR"] = root_node
        log.debug(f"MASTER_ADDR: {os.environ['MASTER_ADDR']}")

        torch_backend = "nccl" if self.trainer.on_gpu else "gloo"

        if not torch.distributed.is_initialized():
            log.info(
                f"initializing ddp (SLURM): GLOBAL_RANK: {global_rank}, MEMBER: {global_rank + 1}/{world_size}"
            )
            torch_distrib.init_process_group(torch_backend,
                                             rank=global_rank,
                                             world_size=world_size)
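
The environment variables prepared above feed PyTorch's default env:// rendezvous; a single-process sketch of the same hand-off (placeholder address and port, gloo backend so no GPU is needed):

    import os
    import torch.distributed as torch_distrib

    # placeholders standing in for the SLURM-derived values above
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "12910"

    if not torch_distrib.is_initialized():
        torch_distrib.init_process_group("gloo", rank=0, world_size=1)
    torch_distrib.destroy_process_group()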
Code example #11
    def __init__(self):
        self._master_address = self._get_master_address()
        self._master_port = self._get_master_port()
        self._local_rank = self._get_local_rank()
        self._global_rank = self._get_global_rank()
        self._world_size = self._get_world_size()
        self._node_rank = self._get_node_rank()

        # set environment variables needed for initializing torch distributed process group
        os.environ["MASTER_ADDR"] = str(self._master_address)
        log.debug(f"MASTER_ADDR: {os.environ['MASTER_ADDR']}")
        os.environ["MASTER_PORT"] = str(self._master_port)
        log.debug(f"MASTER_PORT: {os.environ['MASTER_PORT']}")

        self._rep = ",".join('%s=%s' % (s, getattr(self, "_" + s))
                             for s in ('master_address', 'master_port',
                                       'world_size', 'local_rank', 'node_rank',
                                       'global_rank'))
Code example #12
 def _get_master_port():
     """
     A helper function for accessing the master port.
     Uses the LSF job ID so all ranks can compute the master port.
     """
     # check for user-specified master port
     port = os.environ.get("MASTER_PORT")
     if not port:
         jobid = os.environ.get("LSB_JOBID")
         if not jobid:
             raise ValueError("Could not find job id in environment variable LSB_JOBID")
         port = int(jobid)
         # all ports should be in the 10k+ range
         port = int(port) % 1000 + 10000
         log.debug(f"calculated LSF master port: {port}")
     else:
         log.debug(f"using externally specified master port: {port}")
     return int(port)
Code example #13
    def connect_torchelastic(self, global_rank: int, world_size: int) -> None:
        """
        Override to define your custom way of setting up a distributed environment.

        Lightning's implementation uses env:// init by default and sets the first node as root
        for SLURM-managed clusters.

        Args:
            global_rank: The global process idx.
            world_size: Number of GPUs being used across all nodes (num_nodes * num_gpus).
        """

        if "MASTER_ADDR" not in os.environ:
            rank_zero_warn(
                "MASTER_ADDR environment variable is not defined. Set as localhost"
            )
            os.environ["MASTER_ADDR"] = "127.0.0.1"
        log.debug(f"MASTER_ADDR: {os.environ['MASTER_ADDR']}")

        if "MASTER_PORT" not in os.environ:
            rank_zero_warn(
                "MASTER_PORT environment variable is not defined. Set as 12910"
            )
            os.environ["MASTER_PORT"] = "12910"
        log.debug(f"MASTER_PORT: {os.environ['MASTER_PORT']}")

        if "WORLD_SIZE" in os.environ and int(
                os.environ["WORLD_SIZE"]) != world_size:
            rank_zero_warn(
                f"WORLD_SIZE environment variable ({os.environ['WORLD_SIZE']}) "
                f"is not equal to the computed world size ({world_size}). Ignored."
            )

        torch_backend = "nccl" if self.trainer.on_gpu else "gloo"

        if not torch.distributed.is_initialized():
            log.info(
                f"initializing ddp: GLOBAL_RANK: {global_rank}, MEMBER: {global_rank + 1}/{world_size}"
            )
            torch_distrib.init_process_group(torch_backend,
                                             rank=global_rank,
                                             world_size=world_size)
Code example #14
    def set_nvidia_flags(self, is_slurm_managing_tasks,
                         data_parallel_device_ids):
        if data_parallel_device_ids is None:
            return

        # set the correct cuda visible devices (using pci order)
        os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"

        # when slurm is managing the task it sets the visible devices
        if not is_slurm_managing_tasks:
            if isinstance(data_parallel_device_ids, int):
                id_str = ','.join(
                    str(x) for x in list(range(data_parallel_device_ids)))
                os.environ["CUDA_VISIBLE_DEVICES"] = id_str
            else:
                gpu_str = ','.join([str(x) for x in data_parallel_device_ids])
                os.environ["CUDA_VISIBLE_DEVICES"] = gpu_str

        log.debug(
            f'CUDA_VISIBLE_DEVICES: [{os.environ["CUDA_VISIBLE_DEVICES"]}]')
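
For reference, the two branches above produce CUDA_VISIBLE_DEVICES strings like these (standalone sketch, no GPU required):

    # int input: a device count expands to consecutive ids
    print(",".join(str(x) for x in range(3)))    # -> 0,1,2
    # list input: explicit device ids are joined as given
    print(",".join(str(x) for x in [0, 2, 3]))   # -> 0,2,3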
Code example #15
    def init_amp(self):
        if NATIVE_AMP_AVALAIBLE:
            log.debug("`amp_level` has been deprecated since v0.7.4 (native amp does not require it)")

        assert self.precision in (16, 32), 'only 32 or 16 bit precision supported'

        if self.use_amp and NATIVE_AMP_AVALAIBLE:
            log.info('Using native 16bit precision.')
            return

        if self.use_amp and not APEX_AVAILABLE:  # pragma: no-cover
            raise ModuleNotFoundError(
                "You set `use_amp=True` but do not have apex installed."
                " Install apex first using this guide: https://github.com/NVIDIA/apex#linux"
                " and rerun with `use_amp=True`."
                " This run will NOT use 16 bit precision."
            )

        if self.use_amp:
            log.info('Using APEX 16bit precision.')
Code example #16
    def _get_master_port(self):
        """A helper for getting the master port

        Use the LSF job ID so all ranks can compute the master port
        """
        # check for user-specified master port
        port = os.environ.get("MASTER_PORT")
        if not port:
            var = "LSB_JOBID"
            jobid = os.environ.get(var)
            if not jobid:
                raise ValueError(
                    "Could not find job id -- expected in environment variable %s"
                    % var)
            else:
                port = int(jobid)
                # all ports should be in the 10k+ range
                port = int(port) % 1000 + 10000
            log.debug("calculated master port")
        else:
            log.debug("using externally specified master port")
        return port
Code example #17
 def _del_model(self, filepath: str):
     if self._fs.exists(filepath):
         self._fs.rm(filepath)
         log.debug(f"Removed checkpoint: {filepath}")
Code example #18
 def set_global_rank(self, rank: int) -> None:
     log.debug(
         "LSFEnvironment.set_global_rank was called, but setting global rank is not allowed. Ignored."
     )
Code example #19
 def set_world_size(self, size: int) -> None:
     log.debug(
         "LSFEnvironment.set_world_size was called, but setting world size is not allowed. Ignored."
     )
Code example #20
 def __init__(self):
     self._main_address = self._get_main_address()
     self._main_port = self._get_main_port()
     log.debug(f"MASTER_ADDR: {self._main_address}")
     log.debug(f"MASTER_PORT: {self._main_port}")
Code example #21
def convert(val: str) -> Union[int, float, bool, str]:
    try:
        return ast.literal_eval(val)
    except (ValueError, SyntaxError) as err:
        log.debug(err)
        return val
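
A few illustrative calls, assuming convert (and its log) is importable from the module above; note that lowercase "true" is not a Python literal, so it falls back to the raw string:

    print(convert("42"))      # -> 42 (int)
    print(convert("3.14"))    # -> 3.14 (float)
    print(convert("True"))    # -> True (bool)
    print(convert("true"))    # -> true (ValueError caught, raw string returned)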
Code example #22
def _debug(*args, **kwargs):
    log.debug(*args, **kwargs)
Code example #23
 def _set_init_progress_group_env_vars(self) -> None:
     # set environment variables needed for initializing torch distributed process group
     os.environ["MASTER_ADDR"] = str(self._main_address)
     log.debug(f"MASTER_ADDR: {os.environ['MASTER_ADDR']}")
     os.environ["MASTER_PORT"] = str(self._main_port)
     log.debug(f"MASTER_PORT: {os.environ['MASTER_PORT']}")