def get_checkpoint_file(config: Config, checkpoint_argument: str = "default"):
    """
    Gets the path to a checkpoint file based on a config.

    Args:
        config: config specifying the folder
        checkpoint_argument: Which checkpoint to use: 'default', 'last', 'best',
            a number or a file name. 'default' resolves to 'best' when
            ``job.type`` is 'eval' or 'valid' and to 'last' otherwise; 'last'
            uses the checkpoint number reported by
            ``config.last_checkpoint_number()``. Any other value that is neither
            a number nor 'best' is returned unchanged as a file name.

    Returns:
        path to a checkpoint file, or ``None`` if 'last' was requested (directly
        or via 'default') and ``config.last_checkpoint_number()`` returned None
    """
    if checkpoint_argument == "default":
        # evaluation jobs use the best checkpoint; all other jobs use the last one
        if config.get("job.type") in ["eval", "valid"]:
            checkpoint_argument = "best"
        else:
            checkpoint_argument = "last"

    if checkpoint_argument == "last":
        last_epoch = config.last_checkpoint_number()
        if last_epoch is None:
            return None
        return config.checkpoint_file(last_epoch)

    if is_number(checkpoint_argument, int) or checkpoint_argument == "best":
        return config.checkpoint_file(checkpoint_argument)

    # otherwise, treat it as a filename
    return checkpoint_argument
def checkpoint_file(self, cpt_id: Union[str, int]) -> str:
    """Return the path of the checkpoint file for the given checkpoint id.

    Ids accepted by ``is_number(cpt_id, int)`` are zero-padded to at least five
    digits (e.g. ``checkpoint_00042.pt``); any other id is inserted verbatim
    (e.g. ``checkpoint_best.pt``). The file is located directly in
    ``self.folder``.
    """
    from kge.misc import is_number

    numeric = is_number(cpt_id, int)
    file_name = (
        "checkpoint_{:05d}.pt".format(int(cpt_id))
        if numeric
        else "checkpoint_{}.pt".format(cpt_id)
    )
    return os.path.join(self.folder, file_name)
def checkpoint_file(self, cpt_id: Union[str, int]) -> str:
    """Return the path of the checkpoint file for the given checkpoint id.

    Ids accepted by ``is_number(cpt_id, int)`` are zero-padded to at least five
    digits (e.g. ``checkpoint_00042.pt``); any other id is inserted verbatim
    (e.g. ``checkpoint_best.pt``).

    The file is placed in ``self.folder``, unless the base name of that folder
    contains the substring ``"worker"``; in that case its parent directory is
    used, so that in a distributed setup checkpoints go to the root folder.
    """
    from kge.misc import is_number

    target_dir = self.folder
    # TODO: find a better way to go into root dir to store checkpoints in a
    # distributed setup.
    # NOTE(review): this is a plain substring test, so any folder whose name
    # merely contains "worker" is redirected to its parent as well.
    if "worker" in os.path.basename(target_dir):
        target_dir = os.path.dirname(target_dir)

    if is_number(cpt_id, int):
        file_name = "checkpoint_{:05d}.pt".format(int(cpt_id))
    else:
        file_name = "checkpoint_{}.pt".format(cpt_id)
    return os.path.join(target_dir, file_name)
def set(
    self, key: str, value, create=False, overwrite=Overwrite.Yes, log=False
) -> Any:
    """Set value of specified key.

    Nested dictionary values can be accessed via "." (e.g., "job.type").

    If ``create`` is ``False``, raises :class:`KeyError` when the key (or one
    of its parent keys) does not exist already; otherwise, the new key-value
    pair (and any missing parent dictionaries) is inserted into the
    configuration. A parent dictionary that contains the key ``"+++"`` allows
    new keys to be created below it even when ``create`` is ``False``.

    Numeric strings are converted: for a new key, a string accepted by
    ``is_number`` becomes an ``int`` (preferred) or ``float``; for an existing
    key, a string is converted only to the type of the current value. For an
    existing key, the resulting value must have exactly the same type as the
    current value, otherwise :class:`ValueError` is raised.

    ``overwrite`` applies when the key already has a (non-None) value:
    ``Overwrite.Yes`` replaces it, ``Overwrite.No`` leaves it untouched and
    returns it, ``Overwrite.Error`` raises :class:`ValueError` if the new value
    differs from the current one.

    If ``log`` is ``True``, the assignment is logged via ``self.log``.

    Returns the value stored under ``key`` after the call.
    """
    from kge.misc import is_number

    splits = key.split(".")
    data = self.options

    # flatten path and see if it is valid to be set
    path = []
    for i, part in enumerate(splits[:-1]):
        if part in data:
            # a "+++" entry in an existing sub-dictionary enables key creation
            # for everything below it
            create = create or "+++" in data[part]
        elif create:
            data[part] = dict()
        else:
            msg = (
                "Key '{}' cannot be set because key '{}' does not exist "
                "and no new keys are allowed to be created "
            ).format(key, ".".join(splits[: (i + 1)]))
            if i == 0:
                raise KeyError(msg + "at root level.")
            raise KeyError(msg + "under key '{}'.".format(".".join(splits[:i])))
        path.append(part)
        data = data[part]

    # check correctness of value
    try:
        current_value = data.get(splits[-1])
    except Exception as e:
        # presumably ``data`` is not a dict here, i.e. a prefix of ``key``
        # refers to a non-dictionary value; keep the original error as cause
        raise Exception(
            "These config entries {} {} caused an error.".format(data, splits[-1])
        ) from e

    # NOTE: ``get`` returns None both for a missing key and for a key whose
    # value is None; both cases are treated as "key does not exist".
    if current_value is None:
        if not create:
            msg = (
                f"Key '{key}' cannot be set because it does not exist and "
                "no new keys are allowed to be created "
            )
            if len(path) == 0:
                raise KeyError(msg + "at root level.")
            raise KeyError(msg + ("under key '{}'.").format(".".join(path)))
        # new key: convert numeric strings, preferring int over float
        if isinstance(value, str) and is_number(value, int):
            value = int(value)
        elif isinstance(value, str) and is_number(value, float):
            value = float(value)
    else:
        # existing key: convert numeric strings to the type of the current value
        if (
            isinstance(value, str)
            and isinstance(current_value, float)
            and is_number(value, float)
        ):
            value = float(value)
        elif (
            isinstance(value, str)
            and isinstance(current_value, int)
            and is_number(value, int)
        ):
            value = int(value)
        # exact type comparison; unlike isinstance, this also rejects e.g. a
        # bool for an int-valued key
        if type(value) != type(current_value):
            raise ValueError(
                "key '{}' has incorrect type (expected {}, found {})".format(
                    key, type(current_value), type(value)
                )
            )
        if overwrite == Config.Overwrite.No:
            return current_value
        if overwrite == Config.Overwrite.Error and value != current_value:
            raise ValueError("key '{}' cannot be overwritten".format(key))

    # all fine, set value
    data[splits[-1]] = value
    if log:
        self.log("Set {}={}".format(key, value))
    return value
def main():
    """Command-line entry point.

    Parses the command line and builds the configuration from a config file
    plus command-line overrides. It then handles the resulting command:
    ``dump`` runs ``dump(args)`` and exits. ``start`` loads a config and
    initializes a fresh output folder. ``resume`` loads the config of an
    existing folder.

    Unless the job was only to be created (``start`` with ``run`` disabled),
    the dataset is then loaded and a job is created and run. For ``resume``,
    the job is first resumed from the selected checkpoint.

    The meta-commands ``create``, ``eval``, ``test`` and ``valid`` are passed
    to ``process_meta_command`` together with the settings they stand for
    (e.g. ``eval`` -> ``command=resume``, ``job.type=eval``).

    Any exception raised once the output folder is set up is logged with its
    traceback via ``config.log`` and then re-raised unchanged.
    """
    import sys

    # default config
    config = Config()

    # now parse the arguments
    parser = create_parser(config)
    args, unknown_args = parser.parse_known_args()

    # If there were unknown args, add them to the parser and reparse. The
    # correctness of these arguments will be checked later.
    if unknown_args:
        parser = create_parser(
            config, filter(lambda a: a.startswith("--"), unknown_args)
        )
        args = parser.parse_args()

    # process meta-commands
    process_meta_command(args, "create", {"command": "start", "run": False})
    process_meta_command(args, "eval", {"command": "resume", "job.type": "eval"})
    process_meta_command(
        args, "test", {"command": "resume", "job.type": "eval", "eval.split": "test"}
    )
    process_meta_command(
        args, "valid", {"command": "resume", "job.type": "eval", "eval.split": "valid"}
    )

    # dump command; sys.exit is used because the builtin exit() is only
    # provided by the site module for interactive use
    if args.command == "dump":
        dump(args)
        sys.exit()

    # start command
    if args.command == "start":
        # use toy config file if no config given
        if args.config is None:
            args.config = os.path.join(
                kge_base_dir(), "examples", "toy-complex-train.yaml"
            )
            print("WARNING: No configuration specified; using " + args.config)
        print("Loading configuration {}...".format(args.config))
        config.load(args.config)

    # resume command
    if args.command == "resume":
        # accept either a config file or a folder that contains config.yaml
        if os.path.isdir(args.config):
            folder_config = os.path.join(args.config, "config.yaml")
            if os.path.isfile(folder_config):
                args.config = folder_config
        print("Resuming from configuration {}...".format(args.config))
        config.load(args.config)
        config.folder = os.path.dirname(args.config) or "."
        if not os.path.exists(config.folder):
            raise ValueError(
                "{} is not a valid config file for resuming".format(args.config)
            )

    # overwrite configuration with command line arguments; these argparse
    # attributes are not config keys and are handled separately
    non_config_args = {
        "command",
        "config",
        "run",
        "folder",
        "checkpoint",
        "abort_when_cache_outdated",
    }
    for key, value in vars(args).items():
        if key in non_config_args or value is None:
            continue
        if key == "search.device_pool":
            value = "".join(value).split(",")
        try:
            current_value = config.get(key)
        except KeyError:
            pass  # unknown key: config.set below decides whether that is allowed
        else:
            # bool-valued options are parsed with argparse_bool_type
            if isinstance(current_value, bool):
                value = argparse_bool_type(value)
        config.set(key, value)
        if key == "model":
            config._import(value)

    # initialize output folder
    if args.command == "start":
        if args.folder is None:  # means: set default
            config_name = os.path.splitext(os.path.basename(args.config))[0]
            config.folder = os.path.join(
                kge_base_dir(),
                "local",
                "experiments",
                datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
                + "-"
                + config_name,
            )
        else:
            config.folder = args.folder

    # catch errors to log them
    try:
        if args.command == "start" and not config.init_folder():
            raise ValueError("output folder {} exists already".format(config.folder))
        config.log("Using folder: {}".format(config.folder))

        # determine checkpoint to resume (if any); None means last, and is also
        # the fallback when no checkpoint argument exists, so that the
        # job.resume call below never sees an unbound name
        checkpoint_file = None
        if hasattr(args, "checkpoint"):
            if args.checkpoint == "default":
                if config.get("job.type") in ["eval", "valid"]:
                    checkpoint_file = config.checkpoint_file("best")
            elif is_number(args.checkpoint, int) or args.checkpoint == "best":
                checkpoint_file = config.checkpoint_file(args.checkpoint)
            else:
                # otherwise, treat it as a filename
                checkpoint_file = args.checkpoint

        # disable processing of outdated cached dataset files globally
        Dataset._abort_when_cache_outdated = args.abort_when_cache_outdated

        # log configuration
        config.log("Configuration:")
        config.log(yaml.dump(config.options), prefix=" ")
        config.log("git commit: {}".format(get_git_revision_short_hash()), prefix=" ")

        # set random seeds; a negative value leaves the generator's seed untouched
        if config.get("random_seed.python") > -1:
            import random

            random.seed(config.get("random_seed.python"))
        if config.get("random_seed.torch") > -1:
            import torch

            torch.manual_seed(config.get("random_seed.torch"))
        if config.get("random_seed.numpy") > -1:
            import numpy.random

            numpy.random.seed(config.get("random_seed.numpy"))

        # let's go
        if args.command == "start" and not args.run:
            config.log("Job created successfully.")
        else:
            # load data
            dataset = Dataset.load(config)

            # let's go
            job = Job.create(config, dataset)
            if args.command == "resume":
                job.resume(checkpoint_file)
            job.run()
    except BaseException:
        # log the traceback (without echoing it), then re-raise the original
        # exception unchanged so its traceback and chained context are kept
        config.log(traceback.format_exc(), echo=False)
        raise