예제 #1
0
def get_checkpoint_file(config: Config, checkpoint_argument: str = "default"):
    """
    Gets the path to a checkpoint file based on a config.

    Args:
        config: config specifying the folder
        checkpoint_argument: Which checkpoint to use: 'default', 'last', 'best',
                             a number or a file name. 'default' selects the
                             best checkpoint when ``job.type`` is "eval" or
                             "valid" and the last checkpoint otherwise.

    Returns:
        path to a checkpoint file, or None when the last checkpoint was
        requested (explicitly or via 'default') and
        ``config.last_checkpoint_number()`` returned None
    """
    use_last = checkpoint_argument == "last"
    if checkpoint_argument == "default":
        if config.get("job.type") in ["eval", "valid"]:
            return config.checkpoint_file("best")
        use_last = True

    if use_last:
        # 'last' is a documented option; resolve it exactly like the
        # non-evaluation default instead of mistaking it for a file name
        last_epoch = config.last_checkpoint_number()
        if last_epoch is None:
            return None
        return config.checkpoint_file(last_epoch)

    if checkpoint_argument == "best" or is_number(checkpoint_argument, int):
        return config.checkpoint_file(checkpoint_argument)

    # otherwise, treat it as a filename
    return checkpoint_argument
예제 #2
0
파일: config.py 프로젝트: Allensmile/kge-1
    def checkpoint_file(self, cpt_id: Union[str, int]) -> str:
        """Return path of checkpoint file for given checkpoint id.

        Ids that ``is_number`` accepts as an ``int`` are formatted as a
        zero-padded, five-digit number (``checkpoint_00012.pt``); any other
        id is inserted into the file name unchanged (``checkpoint_best.pt``).
        The file name is joined onto ``self.folder``.
        """
        from kge.misc import is_number

        if is_number(cpt_id, int):
            file_name = "checkpoint_{:05d}.pt".format(int(cpt_id))
        else:
            file_name = "checkpoint_{}.pt".format(cpt_id)
        return os.path.join(self.folder, file_name)
예제 #3
0
파일: config.py 프로젝트: AdrianKs/dist-kge
    def checkpoint_file(self, cpt_id: Union[str, int]) -> str:
        """Return path of checkpoint file for given checkpoint id.

        When the last component of ``self.folder`` contains "worker", the
        checkpoint is placed in that folder's parent directory instead.
        Ids that ``is_number`` accepts as an ``int`` are formatted as a
        zero-padded, five-digit number; any other id is used verbatim.
        """
        from kge.misc import is_number

        # todo: find a better way to go into root dir to store checkpoints in
        #  distributed setup
        target_dir = self.folder
        if "worker" in os.path.basename(target_dir):
            target_dir = os.path.dirname(target_dir)

        file_name = (
            "checkpoint_{:05d}.pt".format(int(cpt_id))
            if is_number(cpt_id, int)
            else "checkpoint_{}.pt".format(cpt_id)
        )
        return os.path.join(target_dir, file_name)
예제 #4
0
파일: config.py 프로젝트: healx/kge
    def set(self,
            key: str,
            value,
            create=False,
            overwrite=Overwrite.Yes,
            log=False) -> Any:
        """Set value of specified key.

        Nested dictionary values can be accessed via "." (e.g., "job.type").

        If ``create`` is ``False``, raises :class:`KeyError` when the key
        does not exist already; otherwise, the new key-value pair is inserted
        into the configuration. Creation is also enabled for everything below
        an existing parent entry that contains "+++".

        String values are converted to numbers where possible: for a new key
        to ``int`` or else ``float`` (as decided by ``is_number``); for an
        existing key to the numeric type of the current value.

        Note that an existing key whose current value is ``None`` is treated
        like a key that does not exist yet.

        Args:
            key: dot-separated name of the option to set
            value: the new value
            create: whether missing keys may be created
            overwrite: what to do when the key already has a value:
                ``Config.Overwrite.No`` keeps and returns the current value,
                ``Config.Overwrite.Error`` raises if the new value differs,
                any other setting replaces the value
            log: whether to log the assignment via ``self.log``

        Returns:
            The value now stored under ``key`` (the unchanged current value
            when ``overwrite`` is ``Config.Overwrite.No``).

        Raises:
            KeyError: if the key or one of its parents does not exist and
                creating keys is not allowed
            ValueError: if the type of ``value`` differs from the type of the
                current value, or if the value would change while
                ``overwrite`` is ``Config.Overwrite.Error``
            Exception: if looking up the final key in its parent entry fails
                (e.g., because the parent entry is not a dictionary)
        """
        from kge.misc import is_number

        splits = key.split(".")
        data = self.options

        # flatten path and see if it is valid to be set
        path = []
        for i, part in enumerate(splits[:-1]):
            if part in data:
                # a "+++" entry in an existing parent permits new keys below it
                create = create or "+++" in data[part]
            else:
                if create:
                    data[part] = dict()
                else:
                    msg = (
                        "Key '{}' cannot be set because key '{}' does not exist "
                        "and no new keys are allowed to be created ").format(
                            key, ".".join(splits[:(i + 1)]))
                    if i == 0:
                        raise KeyError(msg + "at root level.")
                    else:
                        raise KeyError(
                            msg +
                            "under key '{}'.".format(".".join(splits[:i])))

            path.append(part)
            data = data[part]

        # check correctness of value
        try:
            current_value = data.get(splits[-1])
        except Exception as e:
            # only real errors are translated; KeyboardInterrupt/SystemExit
            # propagate untouched and the original cause stays attached
            raise Exception(
                "These config entries {} {} caused an error.".format(
                    data, splits[-1])) from e

        if current_value is None:
            if not create:
                msg = (
                    f"Key '{key}' cannot be set because it does not exist and "
                    "no new keys are allowed to be created ")
                if len(path) == 0:
                    raise KeyError(msg + "at root level.")
                else:
                    raise KeyError(msg +
                                   ("under key '{}'.").format(".".join(path)))

            # new key: no reference type available, so guess int before float
            if isinstance(value, str) and is_number(value, int):
                value = int(value)
            elif isinstance(value, str) and is_number(value, float):
                value = float(value)
        else:
            # existing key: convert strings to the type of the current value
            if (isinstance(value, str) and isinstance(current_value, float)
                    and is_number(value, float)):
                value = float(value)
            elif (isinstance(value, str) and isinstance(current_value, int)
                  and is_number(value, int)):
                value = int(value)
            # exact type comparison on purpose (e.g., bool is not accepted
            # where an int is expected)
            if type(value) != type(current_value):
                raise ValueError(
                    "key '{}' has incorrect type (expected {}, found {})".
                    format(key, type(current_value), type(value)))
            if overwrite == Config.Overwrite.No:
                return current_value
            if overwrite == Config.Overwrite.Error and value != current_value:
                raise ValueError("key '{}' cannot be overwritten".format(key))

        # all fine, set value
        data[splits[-1]] = value
        if log:
            self.log("Set {}={}".format(key, value))
        return value
예제 #5
0
파일: cli.py 프로젝트: Allensmile/kge-1
def main():
    """Command-line entry point: set up a configuration and create/run a job.

    The steps are, in order: parse the command line (twice if unknown
    ``--`` options are present), expand meta-commands, handle ``dump``, load
    the configuration for ``start``/``resume``, apply command-line overrides
    to the configuration, and pick the output folder for ``start``. The
    remaining work (folder initialization, checkpoint selection, logging of
    the configuration, seeding, dataset loading, job creation and execution)
    runs inside a ``try`` block so that any failure is written to the
    configuration's log before being re-raised.

    Raises:
        ValueError: if the folder of the configuration to resume does not
            exist, or if the output folder of a ``start`` command could not
            be initialized (``config.init_folder()`` returned a false value).
    """
    # default config
    config = Config()

    # now parse the arguments
    parser = create_parser(config)
    args, unknown_args = parser.parse_known_args()

    # If there were unknown args, add them to the parser and reparse. The correctness
    # of these arguments will be checked later.
    if len(unknown_args) > 0:
        # only "--"-prefixed tokens are option names; other tokens are skipped
        parser = create_parser(
            config, filter(lambda a: a.startswith("--"), unknown_args)
        )
        args = parser.parse_args()

    # process meta-commands
    process_meta_command(args, "create", {"command": "start", "run": False})
    process_meta_command(args, "eval", {"command": "resume", "job.type": "eval"})
    process_meta_command(
        args, "test", {"command": "resume", "job.type": "eval", "eval.split": "test"}
    )
    process_meta_command(
        args, "valid", {"command": "resume", "job.type": "eval", "eval.split": "valid"}
    )
    # dump command
    if args.command == "dump":
        dump(args)
        exit()

    # start command
    if args.command == "start":
        # use toy config file if no config given
        if args.config is None:
            args.config = kge_base_dir() + "/" + "examples/toy-complex-train.yaml"
            print("WARNING: No configuration specified; using " + args.config)

        print("Loading configuration {}...".format(args.config))
        config.load(args.config)

    # resume command
    if args.command == "resume":
        # a folder may be given instead of a file; use its config.yaml then
        if os.path.isdir(args.config) and os.path.isfile(args.config + "/config.yaml"):
            args.config += "/config.yaml"
        print("Resuming from configuration {}...".format(args.config))
        config.load(args.config)
        # the job folder is the directory containing the configuration file
        config.folder = os.path.dirname(args.config)
        if not config.folder:
            # dirname is empty for a bare file name: use the current directory
            config.folder = "."
        if not os.path.exists(config.folder):
            raise ValueError(
                "{} is not a valid config file for resuming".format(args.config)
            )

    # overwrite configuration with command line arguments
    for key, value in vars(args).items():
        # these arguments control the CLI itself and are not config options
        if key in [
            "command",
            "config",
            "run",
            "folder",
            "checkpoint",
            "abort_when_cache_outdated",
        ]:
            continue
        if value is not None:
            if key == "search.device_pool":
                # concatenate the given pieces, then split on commas
                value = "".join(value).split(",")
            try:
                # convert the argument when the config's current value is a bool
                if isinstance(config.get(key), bool):
                    value = argparse_bool_type(value)
            except KeyError:
                # key not in the config yet; config.set below deals with it
                pass
            config.set(key, value)
            if key == "model":
                config._import(value)

    # initialize output folder
    if args.command == "start":
        if args.folder is None:  # means: set default
            # default: <base dir>/local/experiments/<timestamp>-<config name>
            config_name = os.path.splitext(os.path.basename(args.config))[0]
            config.folder = os.path.join(
                kge_base_dir(),
                "local",
                "experiments",
                datetime.datetime.now().strftime("%Y%m%d-%H%M%S") + "-" + config_name,
            )
        else:
            config.folder = args.folder

    # catch errors to log them
    try:
        if args.command == "start" and not config.init_folder():
            raise ValueError("output folder {} exists already".format(config.folder))
        config.log("Using folder: {}".format(config.folder))

        # determine checkpoint to resume (if any)
        # NOTE(review): checkpoint_file is only bound when args has a
        # "checkpoint" attribute; job.resume below relies on that being the
        # case for the resume command -- confirm against create_parser
        if hasattr(args, "checkpoint"):
            if args.checkpoint == "default":
                if config.get("job.type") in ["eval", "valid"]:
                    checkpoint_file = config.checkpoint_file("best")
                else:
                    checkpoint_file = None  # means last
            elif is_number(args.checkpoint, int) or args.checkpoint == "best":
                checkpoint_file = config.checkpoint_file(args.checkpoint)
            else:
                # otherwise, treat it as a filename
                checkpoint_file = args.checkpoint

        # disable processing of outdated cached dataset files globally
        Dataset._abort_when_cache_outdated = args.abort_when_cache_outdated

        # log configuration
        config.log("Configuration:")
        config.log(yaml.dump(config.options), prefix="  ")
        config.log("git commit: {}".format(get_git_revision_short_hash()), prefix="  ")

        # set random seeds
        # (a generator is only seeded when its configured seed is > -1)
        if config.get("random_seed.python") > -1:
            import random

            random.seed(config.get("random_seed.python"))
        if config.get("random_seed.torch") > -1:
            import torch

            torch.manual_seed(config.get("random_seed.torch"))
        if config.get("random_seed.numpy") > -1:
            import numpy.random

            numpy.random.seed(config.get("random_seed.numpy"))

        # let's go
        if args.command == "start" and not args.run:
            # job was only to be created, not run
            config.log("Job created successfully.")
        else:
            # load data
            dataset = Dataset.load(config)

            # let's go
            job = Job.create(config, dataset)
            if args.command == "resume":
                job.resume(checkpoint_file)
            job.run()
    except BaseException as e:
        # BaseException also covers KeyboardInterrupt and SystemExit, so those
        # are written to the log (not echoed) as well before being re-raised
        tb = traceback.format_exc()
        config.log(tb, echo=False)
        # "from None" suppresses the exception context in the re-raised error
        raise e from None