예제 #1
0
    def test_get_or_default(self):
        """Check ``get`` for a key that was supplied and for one that was not."""
        rdzv_params = RendezvousParameters(
            backend="foobar",
            endpoint="localhost",
            run_id="1234",
            min_nodes=1,
            max_nodes=1,
            timeout1=10,
        )

        # "timeout1" was passed as an extra keyword above, so the fallback
        # value of 20 must not be used.
        supplied = rdzv_params.get("timeout1", 20)
        self.assertEqual(10, supplied)

        # "timeout2" was never passed, so the fallback value is expected.
        missing = rdzv_params.get("timeout2", 60)
        self.assertEqual(60, missing)
예제 #2
0
def create_rdzv_handler(params: RendezvousParameters) -> RendezvousHandler:
    """
    Build an etcd-backed rendezvous handler from ``params``.

    Usage:

    ::

    rdzv_params = RendezvousParameters(
                        backend="etcd",
                        endpoint="192.168.0.42:2379",
                        run_id="123",
                        min_nodes=4,
                        max_nodes=8,
                        timeout=300,
                        last_call_timeout=30,
                        etcd_prefix="custom_prefix",
                        protocol="https",
                        cacert="/etc/kubernetes/certs/ca.crt",
                        cert="/etc/kubernetes/certs/client.crt",
                        key="/etc/kubernetes/certs/client.key")
    # -- or --
    rdzv_params = RendezvousParameters(
                        backend="etcd",
                        endpoint="192.168.0.42:2379",
                        run_id="123",
                        min_nodes=4,
                        max_nodes=8)

    etcd_rdzv_handler = create_rdzv_handler(rdzv_params)


    Where:
        run_id - unique id for this training job instance.
        min_nodes - min number of workers expected to join the rendezvous.
        max_nodes - max number of workers allowed to join the rendezvous.
        timeout - total timeout within which next_rendezvous is expected to
                      succeed; a RendezvousTimeoutError is raised otherwise.
                      Falls back to ``_DEFAULT_TIMEOUT`` when not specified
                      (previously documented as 600 seconds, i.e. 10 minutes;
                      confirm against the constant).
        last_call_timeout - additional wait amount ("last call") after
                            min number of workers has been reached.
                            Falls back to ``_DEFAULT_LAST_CALL_TIMEOUT`` when
                            not specified (previously documented as 30
                            seconds; confirm against the constant).
        etcd_prefix - path prefix (from etcd root), inside which all
                      etcd nodes will be created.
                      Default is "/torchelastic/p2p".

    The remaining options are not read in this function; they are presumably
    consumed by ``_create_etcd_client`` (verify there):
        protocol - http (default) or https to access etcd.
        cacert - CA cert to access etcd, only makes sense with https.
        cert - client cert to access etcd, only makes sense with https.
        key - client key to access etcd, only makes sense with https.

    Returns:
        An ``EtcdRendezvousHandler`` wrapping a freshly built
        ``EtcdRendezvous``.
    """
    etcd_client = _create_etcd_client(params)
    key_prefix = params.get("etcd_prefix", "/torchelastic/p2p")

    # Collect the constructor arguments first so the construction itself
    # stays a single, readable call.
    rendezvous_kwargs = dict(
        client=etcd_client,
        prefix=key_prefix,
        run_id=params.run_id,
        num_min_workers=params.min_nodes,
        num_max_workers=params.max_nodes,
        timeout=params.get_as_int("timeout", _DEFAULT_TIMEOUT),
        last_call_timeout=params.get_as_int(
            "last_call_timeout", _DEFAULT_LAST_CALL_TIMEOUT
        ),
    )
    rendezvous = EtcdRendezvous(**rendezvous_kwargs)
    return EtcdRendezvousHandler(rdzv_impl=rendezvous)
class CreateBackendTest(TestCase):
    """Exercise ``create_backend`` with TCP- and file-store parameters.

    ``setUp`` prepares two parameter sets: ``self._params`` (TCP store, the
    default for every test) and ``self._params_filestore`` (file store backed
    by a temporary file). The ``self._expected_*`` attributes hold the values
    that ``_assert_create_backend_returns_backend`` compares against; tests
    adjust them before calling it.
    """

    def setUp(self) -> None:
        """Build the default TCP/file parameter sets and expected values."""
        # For testing, the default parameters used are for tcp. If a test
        # uses parameters for file store, we set the self._params to
        # self._params_filestore.
        self._params = RendezvousParameters(
            backend="dummy_backend",
            endpoint="localhost:29300",
            run_id="dummy_run_id",
            min_nodes=1,
            max_nodes=1,
            is_host="true",
            store_type="tCp",
            read_timeout="10",
        )

        # mkstemp() returns an already-open OS-level file descriptor along
        # with the path. Only the path is used here, so close the descriptor
        # immediately; otherwise every test leaks one descriptor (and on
        # Windows the open handle would prevent tearDown from removing the
        # file).
        fd, tmp_path = tempfile.mkstemp()
        os.close(fd)

        # Parameters for filestore testing.
        self._params_filestore = RendezvousParameters(
            backend="dummy_backend",
            endpoint=tmp_path,
            run_id="dummy_run_id",
            min_nodes=1,
            max_nodes=1,
            store_type="fIlE",
        )
        self._expected_endpoint_file = tmp_path
        self._expected_temp_dir = tempfile.gettempdir()

        self._expected_endpoint_host = "localhost"
        self._expected_endpoint_port = 29300
        self._expected_store_type = TCPStore
        self._expected_read_timeout = timedelta(seconds=10)

    def tearDown(self) -> None:
        """Remove the temporary file created in ``setUp``."""
        os.remove(self._expected_endpoint_file)

    def _run_test_with_store(self, store_type: str, test_to_run: Callable):
        """
        Use this function to specify the store type to use in a test. If
        not used, the test will default to TCPStore.
        """
        if store_type == "file":
            # Switch both the parameters and the expectations over to the
            # file-store variants before running the test body.
            self._params = self._params_filestore
            self._expected_store_type = FileStore
            self._expected_read_timeout = timedelta(seconds=300)

        test_to_run()

    def _assert_create_backend_returns_backend(self) -> None:
        """Create a backend from ``self._params`` and verify backend and store.

        Checks the backend name, the store type, timeout and endpoint
        attributes against the ``self._expected_*`` values, then writes a
        state through the backend and reads it back from the store.
        """
        backend, store = create_backend(self._params)

        self.assertEqual(backend.name, "c10d")

        self.assertIsInstance(store, self._expected_store_type)

        typecast_store = cast(self._expected_store_type, store)
        self.assertEqual(
            typecast_store.timeout,  # type: ignore[attr-defined]
            self._expected_read_timeout,
        )
        if self._expected_store_type is TCPStore:
            self.assertEqual(
                typecast_store.host,  # type: ignore[attr-defined]
                self._expected_endpoint_host,
            )
            self.assertEqual(
                typecast_store.port,  # type: ignore[attr-defined]
                self._expected_endpoint_port,
            )
        if self._expected_store_type is FileStore:
            if self._params.endpoint:
                self.assertEqual(
                    typecast_store.path,  # type: ignore[attr-defined]
                    self._expected_endpoint_file,
                )
            else:
                # No endpoint was given, so the only thing that can be
                # checked is that the store path lies under the temp dir.
                self.assertTrue(
                    typecast_store.path.startswith(  # type: ignore[attr-defined]
                        self._expected_temp_dir
                    )
                )

        backend.set_state(b"dummy_state")

        state = store.get("torch.rendezvous." + self._params.run_id)

        # The value read back from the store is expected to be the
        # base64-encoded form of what was passed to set_state().
        self.assertEqual(state, b64encode(b"dummy_state"))

    def test_create_backend_returns_backend(self) -> None:
        """Run the common backend assertions for both store types."""
        for store_type in ["tcp", "file"]:
            with self.subTest(store_type=store_type):
                self._run_test_with_store(
                    store_type, self._assert_create_backend_returns_backend
                )

    def test_create_backend_returns_backend_if_is_host_is_false(self) -> None:
        """With ``is_host`` set to "false", a pre-existing store is used."""
        # Bound to a local (although otherwise unused) so that the master
        # store object is not discarded before create_backend() is called.
        store = TCPStore(  # type: ignore[call-arg] # noqa: F841
            self._expected_endpoint_host,
            self._expected_endpoint_port,
            is_master=True,
        )

        self._params.config["is_host"] = "false"

        self._assert_create_backend_returns_backend()

    def test_create_backend_returns_backend_if_is_host_is_not_specified(
        self,
    ) -> None:
        """A backend is still returned when ``is_host`` is absent."""
        del self._params.config["is_host"]

        self._assert_create_backend_returns_backend()

    def test_create_backend_returns_backend_if_is_host_is_not_specified_and_store_already_exists(
        self,
    ) -> None:
        """``is_host`` absent while a master store already exists on the endpoint."""
        # Bound to a local (although otherwise unused) so that the master
        # store object is not discarded before create_backend() is called.
        store = TCPStore(  # type: ignore[call-arg] # noqa: F841
            self._expected_endpoint_host,
            self._expected_endpoint_port,
            is_master=True,
        )

        del self._params.config["is_host"]

        self._assert_create_backend_returns_backend()

    def test_create_backend_returns_backend_if_endpoint_port_is_not_specified(
        self,
    ) -> None:
        """An endpoint without a port is expected to resolve to port 29400."""
        self._params.endpoint = self._expected_endpoint_host

        self._expected_endpoint_port = 29400

        self._assert_create_backend_returns_backend()

    def test_create_backend_returns_backend_if_endpoint_file_is_not_specified(
        self,
    ) -> None:
        """A file store with an empty endpoint should use a temp-dir path."""
        self._params_filestore.endpoint = ""

        self._run_test_with_store(
            "file", self._assert_create_backend_returns_backend
        )

    def test_create_backend_returns_backend_if_store_type_is_not_specified(
        self,
    ) -> None:
        """Without ``store_type`` the store is expected to be a TCPStore."""
        del self._params.config["store_type"]

        self._expected_store_type = TCPStore
        if not self._params.get("read_timeout"):
            self._expected_read_timeout = timedelta(seconds=60)

        self._assert_create_backend_returns_backend()

    def test_create_backend_returns_backend_if_read_timeout_is_not_specified(
        self,
    ) -> None:
        """Without ``read_timeout`` the store timeout is expected to be 60s."""
        del self._params.config["read_timeout"]

        self._expected_read_timeout = timedelta(seconds=60)

        self._assert_create_backend_returns_backend()

    def test_create_backend_raises_error_if_store_is_unreachable(self) -> None:
        """A non-host with no store to connect to gets a connection error."""
        self._params.config["is_host"] = "false"
        # Short read timeout so the test does not wait the full 10 seconds
        # configured in setUp.
        self._params.config["read_timeout"] = "2"

        with self.assertRaisesRegex(
            RendezvousConnectionError,
            r"^The connection to the C10d store has failed. See inner exception for details.$",
        ):
            create_backend(self._params)

    def test_create_backend_raises_error_if_endpoint_is_invalid(self) -> None:
        """An invalid endpoint fails for both host and non-host."""
        for is_host in [True, False]:
            with self.subTest(is_host=is_host):
                self._params.config["is_host"] = str(is_host)

                self._params.endpoint = "dummy_endpoint"

                with self.assertRaisesRegex(
                    RendezvousConnectionError,
                    r"^The connection to the C10d store has failed. See inner exception for "
                    r"details.$",
                ):
                    create_backend(self._params)

    def test_create_backend_raises_error_if_store_type_is_invalid(
        self,
    ) -> None:
        """An unknown ``store_type`` is rejected with ``ValueError``."""
        self._params.config["store_type"] = "dummy_store_type"

        with self.assertRaisesRegex(
            ValueError,
            r"^Invalid store type given. Currently only supports file and tcp.$",
        ):
            create_backend(self._params)

    def test_create_backend_raises_error_if_read_timeout_is_invalid(
        self,
    ) -> None:
        """Zero and negative ``read_timeout`` values raise ``ValueError``."""
        for read_timeout in ["0", "-10"]:
            with self.subTest(read_timeout=read_timeout):
                self._params.config["read_timeout"] = read_timeout

                with self.assertRaisesRegex(
                    ValueError,
                    r"^The read timeout must be a positive integer.$",
                ):
                    create_backend(self._params)

    @mock.patch("tempfile.mkstemp")
    def test_create_backend_raises_error_if_tempfile_creation_fails(
        self, tempfile_mock
    ) -> None:
        """An ``OSError`` from ``tempfile.mkstemp`` surfaces as ``RendezvousError``."""
        tempfile_mock.side_effect = OSError("test error")
        # Set the endpoint to empty so it defaults to creating a temp file
        self._params_filestore.endpoint = ""
        with self.assertRaisesRegex(
            RendezvousError,
            r"The file creation for C10d store has failed. See inner exception for details.",
        ):
            create_backend(self._params_filestore)

    @mock.patch(
        "torch.distributed.elastic.rendezvous.c10d_rendezvous_backend.FileStore"
    )
    def test_create_backend_raises_error_if_file_path_is_invalid(
        self, filestore_mock
    ) -> None:
        """A ``RuntimeError`` from ``FileStore`` surfaces as a connection error."""
        filestore_mock.side_effect = RuntimeError("test error")
        self._params_filestore.endpoint = "bad file path"
        with self.assertRaisesRegex(
            RendezvousConnectionError,
            r"^The connection to the C10d store has failed. See inner exception for "
            r"details.$",
        ):
            create_backend(self._params_filestore)