Exemplo n.º 1
0
    def test_logger_name(self):
        """Check logger names at module, instance, and local scope.

        Loggers created without an explicit name are named after the calling
        module. The expectations therefore use ``__name__`` rather than a
        hard-coded ``"logging_test"``, which only matches when this file
        happens to be imported under exactly that module name.
        """
        local_log = logging.get_logger()
        name_override_log = logging.get_logger("foobar")

        self.assertEqual(__name__, log.name)
        self.assertEqual(__name__, self.clazz_log.name)
        self.assertEqual(__name__, local_log.name)
        # An explicitly supplied name must take precedence over the derived one.
        self.assertEqual("foobar", name_override_log.name)
Exemplo n.º 2
0
import os
import sys
import uuid
from argparse import REMAINDER, ArgumentParser
from typing import Callable, List, Tuple, Union

import torch
from torch.distributed.argparse_util import check_env, env
from torch.distributed.elastic.multiprocessing import Std
from torch.distributed.elastic.multiprocessing.errors import record
from torch.distributed.elastic.rendezvous.utils import _parse_rendezvous_config
from torch.distributed.elastic.utils import macros
from torch.distributed.elastic.utils.logging import get_logger
from torch.distributed.launcher.api import LaunchConfig, elastic_launch

# Module-level logger. With no explicit name, get_logger() names the logger
# after the calling module.
log = get_logger()


def get_args_parser() -> ArgumentParser:
    """Helper function parsing the command line options."""

    parser = ArgumentParser(
        description="Torch Distributed Elastic Training Launcher")

    #
    # Worker/node size related arguments.
    #

    parser.add_argument(
        "--nnodes",
        action=env,
Exemplo n.º 3
0
import sys
import uuid
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List, Optional, Union, cast, Tuple

import torch.distributed.elastic.rendezvous.registry as rdzv_registry
from torch.distributed.elastic import events, metrics
from torch.distributed.elastic.agent.server.api import WorkerSpec, WorkerState  # type: ignore[import]
from torch.distributed.elastic.agent.server.local_elastic_agent import LocalElasticAgent  # type: ignore[import]
from torch.distributed.elastic.multiprocessing import Std
from torch.distributed.elastic.multiprocessing.errors import ChildFailedError, record
from torch.distributed.elastic.rendezvous import RendezvousParameters
from torch.distributed.elastic.rendezvous.utils import parse_rendezvous_endpoint
from torch.distributed.elastic.utils.logging import get_logger

# Module-level logger. With no explicit name, get_logger() names the logger
# after the calling module.
logger = get_logger()


@dataclass
class LaunchConfig:
    """
    min_nodes: Minimum amount of nodes that the user function will
                     be launched on. Elastic agent ensures that the user
                     function start only when the min_nodes amount enters
                     the rendezvous.
    max_nodes: Maximum amount of nodes that the user function
                     will be launched on.
    nproc_per_node: On each node the elastic agent will launch
                          this amount of workers that will execute user
                          defined function.
    rdzv_backend: rdzv_backend to use in the rendezvous (zeus-adapter, etcd).
Exemplo n.º 4
0
 def test_get_logger_custom_name(self):
     """An explicitly requested name is used as the logger's name."""
     expected_name = "test.module"
     self.assertEqual(expected_name, get_logger(expected_name).name)
Exemplo n.º 5
0
 def test_get_logger_none(self):
     """Passing ``None`` as the name falls back to the calling module's name."""
     self.assertEqual(__name__, get_logger(None).name)
Exemplo n.º 6
0
 def test_get_logger(self):
     """Without arguments, the logger is named after the calling module."""
     default_logger = get_logger()
     expected = __name__
     self.assertEqual(expected, default_logger.name)
Exemplo n.º 7
0
 def test_get_logger_different(self):
     """Requesting two distinct names yields loggers with distinct names."""
     first, second = (get_logger(requested) for requested in ("name1", "name2"))
     self.assertNotEqual(first.name, second.name)
Exemplo n.º 8
0
 def setUp(self):
     """Prepare a logger created from inside a method for the name checks."""
     # Chain to the base class first so any fixture it defines is initialized.
     super().setUp()
     self.clazz_log = logging.get_logger()
Exemplo n.º 9
0
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch.distributed.elastic.utils.logging as logging

# Logger created at import time; with no explicit name it is expected to be
# named after this module.
log = logging.get_logger()


class LoggingTest(unittest.TestCase):
    """Tests for ``torch.distributed.elastic.utils.logging``.

    Loggers created without an explicit name take the calling module's name.
    The expectations below therefore compare against ``__name__`` rather than
    a hard-coded ``"logging_test"``, which only matches when this file is
    imported under exactly that module name.
    """

    def setUp(self):
        # Chain to the base class so any fixture it defines is initialized.
        super().setUp()
        # Logger created from inside a method; it should still carry the
        # enclosing module's name.
        self.clazz_log = logging.get_logger()

    def test_logger_name(self):
        """Module-, instance- and local-scope loggers share the module name;
        an explicit name overrides it."""
        local_log = logging.get_logger()
        name_override_log = logging.get_logger("foobar")

        self.assertEqual(__name__, log.name)
        self.assertEqual(__name__, self.clazz_log.name)
        self.assertEqual(__name__, local_log.name)
        self.assertEqual("foobar", name_override_log.name)

    def test_derive_module_name(self):
        """``_derive_module_name(depth=1)`` called here resolves to this module."""
        module_name = logging._derive_module_name(depth=1)
        self.assertEqual(__name__, module_name)


if __name__ == "__main__":
    # Without this guard, executing the file directly would run no tests.
    unittest.main()
Exemplo n.º 10
0
 def setUp(self):
     """Run base-class setup, then attach a logger created inside a method."""
     super().setUp()
     method_logger = logging.get_logger()
     self.clazz_log = method_logger