def __init__(
    self,
    scheduler: torch.optim.lr_scheduler._LRScheduler,
    step_mode: StepMode,
    frequency: int = 1,
):
    """Wrap a PyTorch learning rate scheduler together with its stepping policy.

    Arguments:
        scheduler (:class:`torch.optim.lr_scheduler._LRScheduler`):
            The learning rate scheduler that Determined should use. Must not be ``None``.
        step_mode (:class:`determined.pytorch.LRSchedulerStepMode`):
            Decides when (if ever) Determined invokes ``scheduler.step()``:

            1. ``STEP_EVERY_EPOCH``: ``scheduler.step()`` is called by Determined once
               every ``frequency`` training epoch(s), with no arguments.

            2. ``STEP_EVERY_BATCH``: ``scheduler.step()`` is called by Determined once
               every ``frequency`` training batch(es), with no arguments.

            3. ``MANUAL_STEP``: Determined never calls ``scheduler.step()``. The user
               chooses when to call it and which arguments, if any, to pass.
        frequency:
            How often the batch and epoch step modes fire. Defaults to ``1``.
    """
    # Validate the inputs up front so that nothing is stored on a failed check.
    check.check_not_none(scheduler)
    check.check_isinstance(step_mode, LRScheduler.StepMode)

    self._frequency = frequency
    self._step_mode = step_mode
    self._scheduler = scheduler
def yield_checkpoint_model(
    self, wkld: workload.Workload, respond: workload.ResponseFunc
) -> workload.Stream:
    """Yield a checkpoint workload and report its completion through ``respond``.

    On any rank other than 0, ``respond`` is called immediately with
    ``workload.Skipped()`` and nothing is yielded. On rank 0 a single item
    ``(wkld, [checkpoint_path], callback)`` is yielded from inside the
    ``self.storage_mgr.store_path()`` context. The consumer is expected to call
    ``callback`` with the checkpoint info before resuming this generator. Once
    the storage context has exited, the ``WORKLOAD_COMPLETED`` message built by
    the callback is forwarded to ``respond``.

    Args:
        wkld: The checkpoint workload being processed.
        respond: Function that receives the final response for ``wkld``.
    """
    start_time = _current_timestamp()

    # Checkpointing is the chief's job only; every other rank skips the workload.
    is_chief = self.rendezvous_info.get_rank() == 0
    if not is_chief:
        respond(workload.Skipped())
        return

    # Filled in by the callback below and only sent once the storage context has
    # exited, i.e. after the checkpoint upload completes.
    message = None  # type: Optional[workload.Response]

    def _record_completion(checkpoint_info: workload.Response) -> None:
        nonlocal message

        # storage_id and path are bound by the ``with`` statement below; this
        # callback is only usable while that context is active.
        info = cast(Dict[str, Any], checkpoint_info)
        listing = storage.StorageManager._list_directory(path)
        metadata = storage.StorageMetadata(
            storage_id,
            listing,
            info.get("framework", ""),
            info.get("format", ""),
        )

        logging.info("Saved trial to checkpoint {}".format(metadata.storage_id))
        self.tensorboard_mgr.sync()

        message = {
            "type": "WORKLOAD_COMPLETED",
            "workload": wkld,
            "start_time": start_time,
            "end_time": _current_timestamp(),
            "metrics": metadata,
        }

    with self.storage_mgr.store_path() as (storage_id, path):
        yield wkld, [pathlib.Path(path)], _record_completion

    # Messaging is synchronous, so the layer below must already have invoked the
    # callback by the time control returns here.
    check_not_none(message, "response function did not get called")
    respond(cast(workload.Response, message))
def __init__(
    self,
    scheduler: torch.optim.lr_scheduler._LRScheduler,
    step_mode: StepMode,
    frequency: int = 1,
):
    """Wrap a PyTorch learning rate scheduler together with its stepping policy.

    Arguments:
        scheduler (:class:`torch.optim.lr_scheduler._LRScheduler`):
            The learning rate scheduler that Determined should use. Must not be ``None``.
        step_mode (:class:`determined.pytorch.LRSchedulerStepMode`):
            Decides when (if ever) Determined invokes ``scheduler.step()``:

            1. ``STEP_EVERY_EPOCH``: ``scheduler.step()`` is called by Determined once
               every ``frequency`` training epoch(s), with no arguments.

            2. ``STEP_EVERY_BATCH``: ``scheduler.step()`` is called by Determined once
               every ``frequency`` training batch(es), with no arguments. Gradient
               aggregation is ignored by this mode; prefer
               ``STEP_EVERY_OPTIMIZER_STEP``, which accounts for it.

            3. ``STEP_EVERY_OPTIMIZER_STEP``: ``scheduler.step()`` is called by
               Determined in sync with optimizer steps. When
               ``optimizations.aggregation_frequency`` is unset this behaves like
               ``STEP_EVERY_BATCH``; when it is set, the scheduler is stepped once per
               _effective_ batch. With ``frequency`` set to N, the scheduler is stepped
               every N optimizer steps.

            4. ``MANUAL_STEP``: Determined never calls ``scheduler.step()``. The user
               chooses when to call it and which arguments, if any, to pass.
        frequency:
            How often the batch, epoch and optimizer-step modes fire. Defaults to ``1``.
    """
    # Validate the inputs up front so that nothing is stored on a failed check.
    check.check_not_none(scheduler)
    check.check_isinstance(step_mode, LRScheduler.StepMode)

    self._frequency = frequency
    self._step_mode = step_mode
    self._scheduler = scheduler
def main(args: List[str] = sys.argv[1:]) -> None:
    """Entry point of the CLI: parse ``args`` and run the selected subcommand.

    Unless the subcommand is ``det deploy``, the master's version is checked
    first. If that check fails with an SSL error, the master's certificate chain
    is fetched, its SHA256 fingerprint is shown to the user, and on confirmation
    the chain is written to ``master.crt`` in the config directory and trusted
    from then on.

    The process exits with status 1 when the subcommand fails (via ``die``),
    status 2 when no subcommand was given, and status 3 on ``KeyboardInterrupt``.

    Args:
        args: Command-line arguments without the program name. Note that the
            default is evaluated once, when this function is defined.
    """
    try:
        parser = make_parser()
        argcomplete.autocomplete(parser)

        parsed_args = parser.parse_args(args)

        def die(message: str, always_print_traceback: bool = False) -> None:
            """Print ``message`` in red and exit with status 1."""
            if always_print_traceback or debug_mode():
                import traceback

                traceback.print_exc()

            parser.exit(1, colored(message + "\n", "red"))

        v = vars(parsed_args)
        if not v.get("func"):
            parser.print_usage()
            parser.exit(2, "{}: no subcommand specified\n".format(parser.prog))

        # Trust a previously stored master certificate, if there is one.
        cert_fn = str(auth.get_config_path().joinpath("master.crt"))
        if os.path.exists(cert_fn):
            api.request.set_master_cert_bundle(cert_fn)

        try:
            # For `det deploy`, skip interaction with master.
            if v.get("_command") == DEPLOY_CMD_NAME:
                parsed_args.func(parsed_args)
                return

            try:
                check_version(parsed_args)
            except requests.exceptions.SSLError:
                # An SSLError usually means that we queried a master over HTTPS and got an
                # untrusted cert, so allow the user to store and trust the current cert. (It
                # could also mean that we tried to talk HTTPS on the HTTP port, but
                # distinguishing that based on the exception is annoying, and we'll figure
                # that out in the next step anyway.)
                addr = api.parse_master_address(parsed_args.master)
                check_not_none(addr.hostname)
                check_not_none(addr.port)
                try:
                    ctx = OpenSSL.SSL.Context(OpenSSL.SSL.TLSv1_2_METHOD)
                    conn = OpenSSL.SSL.Connection(ctx, socket.socket())
                    conn.set_tlsext_host_name(cast(str, addr.hostname).encode())
                    conn.connect((addr.hostname, addr.port))
                    conn.do_handshake()
                    cert_pem_data = "".join(
                        OpenSSL.crypto.dump_certificate(
                            OpenSSL.crypto.FILETYPE_PEM, cert
                        ).decode()
                        for cert in conn.get_peer_cert_chain()
                    )
                except OpenSSL.SSL.Error:
                    # die() exits the process, so cert_pem_data is always bound below.
                    die(
                        "Tried to connect over HTTPS but couldn't get a certificate from the "
                        "master; consider using HTTP"
                    )

                cert_hash = hashlib.sha256(ssl.PEM_cert_to_DER_cert(cert_pem_data)).hexdigest()
                cert_fingerprint = ":".join(chunks(cert_hash, 2))

                if not render.yes_or_no(
                    "The master sent an untrusted certificate chain with this SHA256 fingerprint:\n"
                    "{}\nDo you want to trust this certificate from now on?".format(
                        cert_fingerprint
                    )
                ):
                    die("Unable to verify master certificate")

                with open(cert_fn, "w") as out:
                    out.write(cert_pem_data)
                api.request.set_master_cert_bundle(cert_fn)

                # Retry now that the certificate is trusted.
                check_version(parsed_args)

            parsed_args.func(parsed_args)
        except KeyboardInterrupt:
            # Hand over to the outer handler instead of the generic one below.
            raise
        except (api.errors.BadRequestException, api.errors.BadResponseException) as e:
            die("Failed to {}: {}".format(parsed_args.func.__name__, e))
        except api.errors.CorruptTokenCacheException:
            die(
                "Failed to login: Attempted to read a corrupted token cache. "
                "The store has been deleted; please try again."
            )
        except EnterpriseOnlyError as e:
            die(f"Determined Enterprise Edition is required for this functionality: {e}")
        except Exception:
            die("Failed to {}".format(parsed_args.func.__name__), always_print_traceback=True)
    except KeyboardInterrupt:
        # `parser` is bound inside this try block, so it may not exist yet if the
        # interrupt arrived while make_parser() was still running. Write to stderr
        # and exit directly rather than going through parser.exit().
        sys.stderr.write(colored("Interrupting...\n", "red"))
        sys.exit(3)
def main(
    args: List[str] = sys.argv[1:],
) -> None:
    """Entry point of the CLI: parse ``args`` and run the selected subcommand.

    Only the ``det deploy`` argument descriptions are registered (and lazily
    imported) when the first argument is the deploy command or one of its
    aliases; otherwise all argument descriptions are registered.

    Unless the subcommand is ``det deploy``, the CLI certificate singleton is
    configured and the master's version is checked. If that check fails with an
    SSL error, the master's certificate chain is fetched, its SHA256 fingerprint
    is shown to the user, and on confirmation the chain is saved to the
    certificate store and trusted from then on.

    The process exits with status 1 when the subcommand fails (via ``die``),
    status 2 when no subcommand was given, and status 3 on ``KeyboardInterrupt``.

    Args:
        args: Command-line arguments without the program name. Note that the
            default is evaluated once, when this function is defined.
    """
    # TODO: we lazily import "det deploy" but in the future we'd want to lazily import everything.
    parser = make_parser()

    full_cmd, aliases = generate_aliases(deploy_cmd.name)
    is_deploy_cmd = bool(args) and args[0] in [*aliases, full_cmd]
    if is_deploy_cmd:
        from determined.deploy.cli import args_description as deploy_args_description

        add_args(parser, [deploy_args_description])
    else:
        add_args(parser, all_args_description)

    try:
        argcomplete.autocomplete(parser)

        parsed_args = parser.parse_args(args)

        def die(message: str, always_print_traceback: bool = False) -> None:
            """Print ``message`` in red and exit with status 1."""
            if always_print_traceback or debug_mode():
                import traceback

                traceback.print_exc(file=sys.stderr)

            parser.exit(1, colored(message + "\n", "red"))

        v = vars(parsed_args)
        if not v.get("func"):
            parser.print_usage()
            parser.exit(2, "{}: no subcommand specified\n".format(parser.prog))

        try:
            # For `det deploy`, skip interaction with master.
            if is_deploy_cmd:
                parsed_args.func(parsed_args)
                return

            # Configure the CLI's Cert singleton.
            certs.cli_cert = certs.default_load(parsed_args.master)

            try:
                check_version(parsed_args)
            except requests.exceptions.SSLError:
                # An SSLError usually means that we queried a master over HTTPS and got an
                # untrusted cert, so allow the user to store and trust the current cert. (It
                # could also mean that we tried to talk HTTPS on the HTTP port, but
                # distinguishing that based on the exception is annoying, and we'll figure
                # that out in the next step anyway.)
                addr = api.parse_master_address(parsed_args.master)
                check_not_none(addr.hostname)
                check_not_none(addr.port)
                try:
                    ctx = SSL.Context(SSL.TLSv1_2_METHOD)
                    conn = SSL.Connection(ctx, socket.socket())
                    conn.set_tlsext_host_name(cast(str, addr.hostname).encode())
                    conn.connect(cast(Sequence[Union[str, int]], (addr.hostname, addr.port)))
                    conn.do_handshake()
                    cert_pem_data = "".join(
                        crypto.dump_certificate(crypto.FILETYPE_PEM, cert).decode()
                        for cert in conn.get_peer_cert_chain()
                    )
                except (SSL.Error, crypto.Error):
                    # Handshake failures surface as SSL.Error, certificate dumping
                    # failures as crypto.Error; they are separate exception classes,
                    # so both are caught. die() exits the process, so cert_pem_data
                    # is always bound below.
                    die(
                        "Tried to connect over HTTPS but couldn't get a certificate from the "
                        "master; consider using HTTP"
                    )

                cert_hash = hashlib.sha256(ssl.PEM_cert_to_DER_cert(cert_pem_data)).hexdigest()
                cert_fingerprint = ":".join(chunks(cert_hash, 2))

                if not render.yes_or_no(
                    "The master sent an untrusted certificate chain with this SHA256 fingerprint:\n"
                    "{}\nDo you want to trust this certificate from now on?".format(
                        cert_fingerprint
                    )
                ):
                    die("Unable to verify master certificate")

                certs.CertStore(certs.default_store()).set_cert(parsed_args.master, cert_pem_data)
                # Reconfigure the CLI's Cert singleton, but preserve the certificate name.
                old_cert_name = certs.cli_cert.name
                certs.cli_cert = certs.Cert(cert_pem=cert_pem_data, name=old_cert_name)

                # Retry now that the certificate is trusted.
                check_version(parsed_args)

            parsed_args.func(parsed_args)
        except KeyboardInterrupt:
            # Hand over to the outer handler instead of the generic one below.
            raise
        except (api.errors.BadRequestException, api.errors.BadResponseException) as e:
            die("Failed to {}: {}".format(parsed_args.func.__name__, e))
        except api.errors.CorruptTokenCacheException:
            die(
                "Failed to login: Attempted to read a corrupted token cache. "
                "The store has been deleted; please try again."
            )
        except EnterpriseOnlyError as e:
            die(f"Determined Enterprise Edition is required for this functionality: {e}")
        except Exception:
            die("Failed to {}".format(parsed_args.func.__name__), always_print_traceback=True)
    except KeyboardInterrupt:
        # die() may not be defined yet.
        if debug_mode():
            import traceback

            traceback.print_exc(file=sys.stderr)

        print(colored("Interrupting...\n", "red"), file=sys.stderr)
        # Use sys.exit rather than the builtin exit(), which is only injected by
        # the site module and is not guaranteed to exist.
        sys.exit(3)