Пример #1
0
    def __init__(
        self,
        scheduler: torch.optim.lr_scheduler._LRScheduler,
        step_mode: StepMode,
        frequency: int = 1,
    ) -> None:
        """LRScheduler constructor.

        Args:
            scheduler (:class:`torch.optim.lr_scheduler._LRScheduler`):
                Learning rate scheduler to be used by Determined. Must not be ``None``.
            step_mode (:class:`determined.pytorch.LRSchedulerStepMode`):
                The strategy Determined will use to call (or not call) scheduler.step().
                Must be an ``LRScheduler.StepMode`` member.

                1. ``STEP_EVERY_EPOCH``: Determined will call scheduler.step() after
                   every ``frequency`` training epoch(s). No arguments will be passed to step().

                2. ``STEP_EVERY_BATCH``: Determined will call scheduler.step() after every
                   ``frequency`` training batch(es). No arguments will be passed to step().

                3. ``MANUAL_STEP``: Determined will not call scheduler.step() at all.
                   It is up to the user to decide when to call scheduler.step(),
                   and whether to pass any arguments.
            frequency (int):
                Sets the frequency at which the batch and epoch step modes get triggered,
                i.e. scheduler.step() is called once every ``frequency`` epochs or batches.
                Must be at least 1; defaults to 1.

        Raises:
            ValueError: If ``frequency`` is less than 1.
        """
        check.check_not_none(scheduler)
        check.check_isinstance(step_mode, LRScheduler.StepMode)
        # A non-positive interval cannot describe "every N epochs/batches"; reject it at
        # construction time instead of letting it surface later when steps are scheduled.
        if frequency < 1:
            raise ValueError(
                "LRScheduler frequency must be a positive integer, got {}".format(frequency)
            )

        self._scheduler = scheduler
        self._step_mode = step_mode
        self._frequency = frequency
Пример #2
0
    def yield_checkpoint_model(
        self, wkld: workload.Workload, respond: workload.ResponseFunc
    ) -> workload.Stream:
        """Hand a checkpoint workload to the layer below and report its completion.

        Containers whose rendezvous rank is not 0 answer immediately with
        ``workload.Skipped()`` and yield nothing. The rank-0 container opens a storage
        path, yields ``(wkld, [Path(path)], callback)``, and after the storage context
        has exited passes the ``WORKLOAD_COMPLETED`` message built by the callback to
        ``respond``.
        """
        started_at = _current_timestamp()

        if self.rendezvous_info.get_rank() != 0:
            # Checkpointing is the chief's job; everyone else just acknowledges.
            respond(workload.Skipped())
            return

        # Populated by the callback; held back until the storage context below exits so
        # it is reported only after the checkpoint upload completes.
        completed_msg: Optional[workload.Response] = None

        def _on_checkpoint_saved(checkpoint_info: workload.Response) -> None:
            nonlocal completed_msg

            info = cast(Dict[str, Any], checkpoint_info)
            # `storage_id` and `path` are bound by the `with` statement further down; the
            # lower layer invokes this callback while that context is still open.
            saved = storage.StorageMetadata(
                storage_id,
                storage.StorageManager._list_directory(path),
                info.get("framework", ""),
                info.get("format", ""),
            )

            logging.info(f"Saved trial to checkpoint {saved.storage_id}")
            self.tensorboard_mgr.sync()

            completed_msg = {
                "type": "WORKLOAD_COMPLETED",
                "workload": wkld,
                "start_time": started_at,
                "end_time": _current_timestamp(),
                "metrics": saved,
            }

        with self.storage_mgr.store_path() as (storage_id, path):
            yield wkld, [pathlib.Path(path)], _on_checkpoint_saved

        # Messaging is synchronous, so the lower layer has already run the callback.
        check_not_none(completed_msg, "response function did not get called")
        respond(cast(workload.Response, completed_msg))
    def __init__(
        self,
        scheduler: torch.optim.lr_scheduler._LRScheduler,
        step_mode: StepMode,
        frequency: int = 1,
    ) -> None:
        """LRScheduler constructor.

        Args:
            scheduler (:class:`torch.optim.lr_scheduler._LRScheduler`):
                Learning rate scheduler to be used by Determined. Must not be ``None``.
            step_mode (:class:`determined.pytorch.LRSchedulerStepMode`):
                The strategy Determined will use to call (or not call) scheduler.step().
                Must be an ``LRScheduler.StepMode`` member.

                1. ``STEP_EVERY_EPOCH``: Determined will call scheduler.step() after
                   every ``frequency`` training epoch(s). No arguments will be passed to step().

                2. ``STEP_EVERY_BATCH``: Determined will call scheduler.step() after every
                   ``frequency`` training batch(es). No arguments will be passed to step().
                   This option does not take gradient aggregation into account; in that
                   case ``STEP_EVERY_OPTIMIZER_STEP``, which does, is recommended.

                3. ``STEP_EVERY_OPTIMIZER_STEP``: Determined will call scheduler.step() in sync
                   with optimizer steps. With ``optimizations.aggregation_frequency`` unset, this
                   is equivalent to ``STEP_EVERY_BATCH``; when it is set, it ensures the LR
                   scheduler is stepped every *effective* batch.

                   If the option ``frequency`` is set to some value N, Determined will step the LR
                   scheduler every N optimizer steps.

                4. ``MANUAL_STEP``: Determined will not call scheduler.step() at all.
                   It is up to the user to decide when to call scheduler.step(),
                   and whether to pass any arguments.
            frequency (int):
                Sets the frequency at which the epoch, batch, and optimizer-step step modes
                get triggered: scheduler.step() is called once every ``frequency`` epochs,
                batches, or optimizer steps respectively. It has no effect with
                ``MANUAL_STEP``. Must be at least 1; defaults to 1.

        Raises:
            ValueError: If ``frequency`` is less than 1.
        """
        check.check_not_none(scheduler)
        check.check_isinstance(step_mode, LRScheduler.StepMode)
        # A non-positive interval cannot describe "every N epochs/batches/steps"; reject it
        # at construction time instead of letting it surface later when steps are scheduled.
        if frequency < 1:
            raise ValueError(
                "LRScheduler frequency must be a positive integer, got {}".format(frequency)
            )

        self._scheduler = scheduler
        self._step_mode = step_mode
        self._frequency = frequency
Пример #4
0
def main(args: List[str] = sys.argv[1:]) -> None:
    """Entry point for the ``det`` CLI.

    Parses ``args`` and runs the selected subcommand. If ``master.crt`` exists in the
    CLI config directory it is used as the master's certificate bundle. For commands
    other than ``det deploy`` the master version is checked first; if that check fails
    with an SSL error, the master's certificate chain is fetched, its fingerprint is
    shown, and -- if the user agrees -- the chain is written to ``master.crt`` and
    trusted from then on.

    Exit statuses: 2 when no subcommand is given, 1 when the subcommand fails (via
    ``die``), 3 when interrupted with Ctrl-C.
    """
    try:
        parser = make_parser()
        argcomplete.autocomplete(parser)

        parsed_args = parser.parse_args(args)

        def die(message: str, always_print_traceback: bool = False) -> None:
            # parser.exit() raises SystemExit, so die() never returns.
            if always_print_traceback or debug_mode():
                import traceback

                traceback.print_exc()

            parser.exit(1, colored(message + "\n", "red"))

        v = vars(parsed_args)
        if not v.get("func"):
            parser.print_usage()
            parser.exit(2, "{}: no subcommand specified\n".format(parser.prog))

        cert_fn = str(auth.get_config_path().joinpath("master.crt"))
        if os.path.exists(cert_fn):
            api.request.set_master_cert_bundle(cert_fn)

        try:
            # For `det deploy`, skip interaction with master.
            if v.get("_command") == DEPLOY_CMD_NAME:
                parsed_args.func(parsed_args)
                return

            try:
                check_version(parsed_args)
            except requests.exceptions.SSLError:
                # An SSLError usually means that we queried a master over HTTPS and got an untrusted
                # cert, so allow the user to store and trust the current cert. (It could also mean
                # that we tried to talk HTTPS on the HTTP port, but distinguishing that based on the
                # exception is annoying, and we'll figure that out in the next step anyway.)
                addr = api.parse_master_address(parsed_args.master)
                check_not_none(addr.hostname)
                check_not_none(addr.port)
                no_cert_msg = (
                    "Tried to connect over HTTPS but couldn't get a certificate from the "
                    "master; consider using HTTP"
                )
                # Keep a handle on the raw socket so it is closed on every exit path.
                sock = socket.socket()
                try:
                    ctx = OpenSSL.SSL.Context(OpenSSL.SSL.TLSv1_2_METHOD)
                    conn = OpenSSL.SSL.Connection(ctx, sock)
                    conn.set_tlsext_host_name(cast(str, addr.hostname).encode())
                    conn.connect((addr.hostname, addr.port))
                    conn.do_handshake()
                    # get_peer_cert_chain() returns None when the peer sent no chain.
                    peer_chain = conn.get_peer_cert_chain()
                    if not peer_chain:
                        die(no_cert_msg)
                    cert_pem_data = "".join(
                        OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM, cert).decode()
                        for cert in peer_chain
                    )
                except (OpenSSL.SSL.Error, OpenSSL.crypto.Error):
                    die(no_cert_msg)
                finally:
                    sock.close()

                # NOTE(review): for a multi-certificate chain this decodes the concatenated PEM
                # text as a single blob, so the fingerprint may not equal any individual
                # certificate's fingerprint -- confirm this is the intended value to show.
                cert_hash = hashlib.sha256(ssl.PEM_cert_to_DER_cert(cert_pem_data)).hexdigest()
                cert_fingerprint = ":".join(chunks(cert_hash, 2))

                if not render.yes_or_no(
                        "The master sent an untrusted certificate chain with this SHA256 fingerprint:\n"
                        "{}\nDo you want to trust this certificate from now on?"
                        .format(cert_fingerprint)):
                    die("Unable to verify master certificate")

                with open(cert_fn, "w") as out:
                    out.write(cert_pem_data)
                api.request.set_master_cert_bundle(cert_fn)

                check_version(parsed_args)

            parsed_args.func(parsed_args)
        except KeyboardInterrupt:
            # Let the outer handler report the interruption.
            raise
        except (api.errors.BadRequestException,
                api.errors.BadResponseException) as e:
            die("Failed to {}: {}".format(parsed_args.func.__name__, e))
        except api.errors.CorruptTokenCacheException:
            die("Failed to login: Attempted to read a corrupted token cache. "
                "The store has been deleted; please try again.")
        except EnterpriseOnlyError as e:
            die(f"Determined Enterprise Edition is required for this functionality: {e}")
        except Exception:
            die("Failed to {}".format(parsed_args.func.__name__),
                always_print_traceback=True)
    except KeyboardInterrupt:
        # `parser` is unbound if the interrupt arrives while make_parser() is running, so
        # write the message directly rather than calling parser.exit().
        sys.stderr.write(colored("Interrupting...\n", "red"))
        sys.exit(3)
Пример #5
0
def main(args: List[str] = sys.argv[1:]) -> None:
    """Entry point for the ``det`` CLI.

    Parses ``args`` and runs the selected subcommand. For ``det deploy`` (or one of its
    aliases) only the deploy argument definitions are loaded and the master is never
    contacted. For every other command the CLI certificate is loaded and the master
    version is checked first; if that check fails with an SSL error, the master's
    certificate chain is fetched, its fingerprint is shown, and -- if the user agrees --
    the chain is stored in the certificate store and trusted from then on.

    Exit statuses: 2 when no subcommand is given, 1 when the subcommand fails (via
    ``die``), 3 when interrupted with Ctrl-C.
    """
    # TODO: we lazily import "det deploy" but in the future we'd want to lazily import everything.
    parser = make_parser()

    full_cmd, aliases = generate_aliases(deploy_cmd.name)
    is_deploy_cmd = len(args) > 0 and any(args[0] == alias for alias in [*aliases, full_cmd])
    if is_deploy_cmd:
        from determined.deploy.cli import args_description as deploy_args_description

        add_args(parser, [deploy_args_description])
    else:
        add_args(parser, all_args_description)

    try:
        argcomplete.autocomplete(parser)

        parsed_args = parser.parse_args(args)

        def die(message: str, always_print_traceback: bool = False) -> None:
            # parser.exit() raises SystemExit, so die() never returns.
            if always_print_traceback or debug_mode():
                import traceback

                traceback.print_exc(file=sys.stderr)

            parser.exit(1, colored(message + "\n", "red"))

        v = vars(parsed_args)
        if not v.get("func"):
            parser.print_usage()
            parser.exit(2, "{}: no subcommand specified\n".format(parser.prog))

        try:
            # For `det deploy`, skip interaction with master.
            if is_deploy_cmd:
                parsed_args.func(parsed_args)
                return

            # Configure the CLI's Cert singleton.
            certs.cli_cert = certs.default_load(parsed_args.master)

            try:
                check_version(parsed_args)
            except requests.exceptions.SSLError:
                # An SSLError usually means that we queried a master over HTTPS and got an untrusted
                # cert, so allow the user to store and trust the current cert. (It could also mean
                # that we tried to talk HTTPS on the HTTP port, but distinguishing that based on the
                # exception is annoying, and we'll figure that out in the next step anyway.)
                addr = api.parse_master_address(parsed_args.master)
                check_not_none(addr.hostname)
                check_not_none(addr.port)
                no_cert_msg = (
                    "Tried to connect over HTTPS but couldn't get a certificate from the "
                    "master; consider using HTTP"
                )
                # Keep a handle on the raw socket so it is closed on every exit path.
                sock = socket.socket()
                try:
                    ctx = SSL.Context(SSL.TLSv1_2_METHOD)
                    conn = SSL.Connection(ctx, sock)
                    conn.set_tlsext_host_name(cast(str, addr.hostname).encode())
                    conn.connect(cast(Sequence[Union[str, int]], (addr.hostname, addr.port)))
                    conn.do_handshake()
                    # get_peer_cert_chain() returns None when the peer sent no chain.
                    peer_chain = conn.get_peer_cert_chain()
                    if not peer_chain:
                        die(no_cert_msg)
                    cert_pem_data = "".join(
                        crypto.dump_certificate(crypto.FILETYPE_PEM, cert).decode()
                        for cert in peer_chain
                    )
                except (SSL.Error, crypto.Error):
                    # Handshake failures raise SSL.Error, which is not a crypto.Error;
                    # certificate serialization failures raise crypto.Error.
                    die(no_cert_msg)
                finally:
                    sock.close()

                # NOTE(review): for a multi-certificate chain this decodes the concatenated PEM
                # text as a single blob, so the fingerprint may not equal any individual
                # certificate's fingerprint -- confirm this is the intended value to show.
                cert_hash = hashlib.sha256(ssl.PEM_cert_to_DER_cert(cert_pem_data)).hexdigest()
                cert_fingerprint = ":".join(chunks(cert_hash, 2))

                if not render.yes_or_no(
                        "The master sent an untrusted certificate chain with this SHA256 fingerprint:\n"
                        "{}\nDo you want to trust this certificate from now on?"
                        .format(cert_fingerprint)):
                    die("Unable to verify master certificate")

                certs.CertStore(certs.default_store()).set_cert(
                    parsed_args.master, cert_pem_data)
                # Reconfigure the CLI's Cert singleton, but preserve the certificate name.
                old_cert_name = certs.cli_cert.name
                certs.cli_cert = certs.Cert(cert_pem=cert_pem_data, name=old_cert_name)

                check_version(parsed_args)

            parsed_args.func(parsed_args)
        except KeyboardInterrupt:
            # Let the outer handler report the interruption.
            raise
        except (api.errors.BadRequestException,
                api.errors.BadResponseException) as e:
            die("Failed to {}: {}".format(parsed_args.func.__name__, e))
        except api.errors.CorruptTokenCacheException:
            die("Failed to login: Attempted to read a corrupted token cache. "
                "The store has been deleted; please try again.")
        except EnterpriseOnlyError as e:
            die(f"Determined Enterprise Edition is required for this functionality: {e}")
        except Exception:
            die("Failed to {}".format(parsed_args.func.__name__),
                always_print_traceback=True)
    except KeyboardInterrupt:
        # die() may not be defined yet.
        if debug_mode():
            import traceback

            traceback.print_exc(file=sys.stderr)

        print(colored("Interrupting...\n", "red"), file=sys.stderr)
        # sys.exit rather than the site-provided exit(), which is meant for the
        # interactive shell and is missing when Python runs without the site module.
        sys.exit(3)