Example #1
    def run(self):
        # read the grid search parameter ranges
        all_keys = []
        all_keys_short = []
        all_values = []
        all_indexes = []
        grid_configs = self.config.get("grid_search.parameters")
        for k, v in sorted(Config.flatten(grid_configs).items()):
            all_keys.append(k)
            short_key = k[k.rfind(".") + 1:]
            if "_" in short_key:
                # just keep first letter after each _
                all_keys_short.append("".join(
                    map(lambda s: s[0], short_key.split("_"))))
            else:
                # keep up to three letters
                all_keys_short.append(short_key[:3])
            all_values.append(v)
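            # one index range per parameter; itertools.product over these
            # ranges later enumerates every parameter combination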
            all_indexes.append(range(len(v)))

        # create search configs
        search_configs = []
        for indexes in itertools.product(*all_indexes):
            # obtain the values of the changed parameters for this combination
            values = [all_values[i][j] for i, j in enumerate(indexes)]

            # create search configuration and check whether correct
            dummy_config = self.config.clone()
            search_config = Config(load_default=False)
            search_config.options["folder"] = "_".join(
                map(lambda i: all_keys_short[i] + str(values[i]),
                    range(len(values))))
            for i, key in enumerate(all_keys):
                # set on a full clone to check that the key/value pair is valid
                dummy_config.set(key, values[i])
                search_config.set(key, values[i], create=True)

            # and remember it
            search_configs.append(search_config.options)

        # create configuration file of search job
        self.config.set("search.type", "manual")
        self.config.set("manual_search.configurations", search_configs)
        self.config.save(os.path.join(self.config.folder, "config.yaml"))

        # and run it
        if self.config.get("grid_search.run"):
            job = Job.create(self.config, self.dataset, parent_job=self)
            job.resume()
            job.run()
        else:
            self.config.log(
                "Skipping running of search job as requested by user...")
Example #2
File: search.py Project: uma-pi1/kge
def _run_train_job(sicnk, device=None):
    """Runs a training job and returns the trace entry of its best validation result.

    Also takes care of appropriate tracing.

    """

    search_job, train_job_index, train_job_config, train_job_count, trace_keys = sicnk

    try:
        # load the job
        if device is not None:
            train_job_config.set("job.device", device)
        search_job.config.log(
            "Starting training job {} ({}/{}) on device {}...".format(
                train_job_config.folder,
                train_job_index + 1,
                train_job_count,
                train_job_config.get("job.device"),
            ))
        checkpoint_file = get_checkpoint_file(train_job_config)
        if checkpoint_file is not None:
            checkpoint = load_checkpoint(checkpoint_file,
                                         train_job_config.get("job.device"))
            job = Job.create_from(
                checkpoint=checkpoint,
                new_config=train_job_config,
                dataset=search_job.dataset,
                parent_job=search_job,
            )
        else:
            job = Job.create(
                config=train_job_config,
                dataset=search_job.dataset,
                parent_job=search_job,
            )

        # process the trace entries so far (in case of a resumed job)
        metric_name = search_job.config.get("valid.metric")
        valid_trace = []

        def copy_to_search_trace(job, trace_entry=None):
            if trace_entry is None:
                trace_entry = job.valid_trace[-1]
            trace_entry = copy.deepcopy(trace_entry)
            for key in trace_keys:
                # Process deprecated options to some extent. Support key renames, but
                # not value renames.
                actual_key = {key: None}
                _process_deprecated_options(actual_key)
                if len(actual_key) > 1:
                    raise KeyError(
                        f"{key} is deprecated but cannot be handled automatically"
                    )
                actual_key = next(iter(actual_key.keys()))
                value = train_job_config.get(actual_key)
                trace_entry[key] = value

            trace_entry["folder"] = os.path.split(train_job_config.folder)[1]
            metric_value = Trace.get_metric(trace_entry, metric_name)
            trace_entry["metric_name"] = metric_name
            trace_entry["metric_value"] = metric_value
            trace_entry["parent_job_id"] = search_job.job_id
            search_job.config.trace(**trace_entry)
            valid_trace.append(trace_entry)

        for trace_entry in job.valid_trace:
            copy_to_search_trace(None, trace_entry)

        # run the job (adding new trace entries as we go)
        # TODO make this less hacky (easier once integrated into SearchJob)
        from kge.job import ManualSearchJob

        if not isinstance(
                search_job,
                ManualSearchJob) or search_job.config.get("manual_search.run"):
            job.post_valid_hooks.append(copy_to_search_trace)
            job.run()
        else:
            search_job.config.log(
                "Skipping running of training job as requested by user.")
            return (train_job_index, None, None)

        # analyze the result
        search_job.config.log("Best result in this training job:")
        best = None
        best_metric = None
        for trace_entry in valid_trace:
            metric = trace_entry["metric_value"]
            if not best or Metric(search_job).better(metric, best_metric):
                best = trace_entry
                best_metric = metric

        # record the best result of this job
        best["child_job_id"] = best["job_id"]
        for k in ["job", "job_id", "type", "parent_job_id", "scope", "event"]:
            if k in best:
                del best[k]
        search_job.trace(
            event="search_completed",
            echo=True,
            echo_prefix="  ",
            log=True,
            scope="train",
            **best,
        )

        # force releasing the GPU memory of the job to avoid memory leakage
        del job
        gc.collect()

        return (train_job_index, best, best_metric)
    except BaseException as e:
        search_job.config.log("Trial {:05d} failed: {}".format(
            train_job_index, repr(e)))
        if search_job.on_error == "continue":
            return (train_job_index, None, None)
        else:
            search_job.config.log(
                "Aborting search due to failure of trial {:05d}".format(
                    train_job_index))
            raise e
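
The deprecated-option handling above relies on _process_deprecated_options renaming keys of the passed dict in place; a minimal sketch of that contract, with a hypothetical rename table and a stand-in helper:

# hypothetical rename table; the real mapping lives in kge's config handling
DEPRECATED_KEY_RENAMES = {"train.lr": "train.optimizer_args.lr"}

def process_deprecated_options(options: dict) -> None:
    # rename deprecated keys in place, keeping their values
    for old, new in DEPRECATED_KEY_RENAMES.items():
        if old in options:
            options[new] = options.pop(old)

# mirror of the lookup pattern used in copy_to_search_trace above
actual_key = {"train.lr": None}
process_deprecated_options(actual_key)
assert len(actual_key) == 1  # a one-to-many rename could not be handled
actual_key = next(iter(actual_key))  # "train.optimizer_args.lr"
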
Example #3
    def run(self):
        torch_device = self.config.get("job.device")
        if self.config.get("job.device") == "cuda":
            torch_device = "cuda:0"
        if torch_device != "cpu":
            torch.cuda.set_device(torch_device)
        # seeds need to be set in every process
        set_seeds(self.config, self.rank)

        os.environ["MASTER_ADDR"] = self.config.get("job.distributed.master_ip")
        os.environ["MASTER_PORT"] = self.config.get("job.distributed.master_port")
        min_rank = get_min_rank(self.config)
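        # ranks below min_rank are reserved for non-worker processes,
        # so every worker rank is shifted by min_rank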
        print("before init", self.rank + min_rank)
        dist.init_process_group(
            backend="gloo",
            init_method="env://",
            world_size=self.num_total_workers + min_rank,
            rank=self.rank + min_rank,
            timeout=datetime.timedelta(hours=6),
        )
        worker_ranks = list(range(min_rank, self.num_total_workers+min_rank))
        worker_group = dist.new_group(worker_ranks, timeout=datetime.timedelta(hours=6))

        # create parameter server
        server = None
        if self.config.get("job.distributed.parameter_server") == "lapse":
            os.environ["DMLC_NUM_WORKER"] = "0"
            os.environ["DMLC_NUM_SERVER"] = str(self.num_total_workers)
            os.environ["DMLC_ROLE"] = "server"
            os.environ["DMLC_PS_ROOT_URI"] = self.config.get(
                "job.distributed.master_ip"
            )
            os.environ["DMLC_PS_ROOT_PORT"] = self.config.get(
                "job.distributed.lapse_port"
            )

            num_workers_per_server = 1
            lapse.setup(self.num_keys, num_workers_per_server)
            server = lapse.Server(self.num_keys, self.embedding_dim + self.optimizer_dim)
        elif self.config.get("job.distributed.parameter_server") == "shared":
            server = self.parameters

        # create train-worker config, dataset and folder
        device_pool: list = self.config.get("job.device_pool")
        if len(device_pool) == 0:
            device_pool.append(self.config.get("job.device"))
        worker_id = self.rank
        config = deepcopy(self.config)
        config.set("job.device", device_pool[worker_id % len(device_pool)])
        config.folder = os.path.join(self.config.folder, f"worker-{self.rank}")
        config.init_folder()
        dataset = deepcopy(self.dataset)

        parameter_client = KgeParameterClient.create(
            client_type=self.config.get("job.distributed.parameter_server"),
            server_id=0,
            client_id=worker_id + min_rank,
            embedding_dim=self.embedding_dim + self.optimizer_dim,
            server=server,
            num_keys=self.num_keys,
            num_meta_keys=self.num_meta_keys,
            worker_group=worker_group,
        )
        # don't re-initialize the model after loading checkpoint
        init_for_load_only = self.checkpoint_name is not None
        job = Job.create(
            config=config,
            dataset=dataset,
            parameter_client=parameter_client,
            init_for_load_only=init_for_load_only,
        )
        if self.checkpoint_name is not None:
            checkpoint = load_checkpoint(self.checkpoint_name)
            job._load(checkpoint)
            job.load_distributed(checkpoint_name=self.checkpoint_name)

        job.run()

        # all done, clean up
        print("shut down everything")
        parameter_client.barrier()
        if hasattr(job, "work_scheduler_client"):
            job.work_scheduler_client.shutdown()
        parameter_client.shutdown()
        # delete all references to the parameter client to properly shut down lapse
        # del job
        del job.parameter_client
        del job.model.get_s_embedder().parameter_client
        del job.model.get_p_embedder().parameter_client
        del job.model
        if hasattr(job, "optimizer"):
            del job.optimizer
        del parameter_client
        gc.collect()  # make sure lapse-worker destructor is called
        # shutdown server
        if server is not None and type(server) != torch.Tensor:
            server.shutdown()
        if self.result_pipe is not None:
            if hasattr(job, "valid_trace"):
                # when validating from a checkpoint only, there is no valid trace
                self.result_pipe.send(job.valid_trace)
            else:
                self.result_pipe.send(None)
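
For context, a minimal standalone sketch of the env:// rendezvous used above; the address and port are placeholders, and a single process acts as rank 0 of a world of size 1:

import datetime
import os

import torch.distributed as dist

os.environ["MASTER_ADDR"] = "127.0.0.1"  # placeholder master address
os.environ["MASTER_PORT"] = "29500"      # placeholder free port

# env:// reads MASTER_ADDR and MASTER_PORT set above
dist.init_process_group(
    backend="gloo",
    init_method="env://",
    world_size=1,
    rank=0,
    timeout=datetime.timedelta(minutes=30),
)
dist.barrier()  # all ranks (here: just one) synchronize
dist.destroy_process_group()
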
Example #4
def main():
    # default config
    config = Config()

    # now parse the arguments
    parser = create_parser(config)
    args, unknown_args = parser.parse_known_args()

    # If there were unknown args, add them to the parser and reparse. The
    # correctness of these arguments will be checked later.
    if len(unknown_args) > 0:
        parser = create_parser(
            config, filter(lambda a: a.startswith("--"), unknown_args)
        )
        args = parser.parse_args()

    # process meta-commands
    process_meta_command(args, "create", {"command": "start", "run": False})
    process_meta_command(args, "eval", {"command": "resume", "job.type": "eval"})
    process_meta_command(
        args, "test", {"command": "resume", "job.type": "eval", "eval.split": "test"}
    )
    process_meta_command(
        args, "valid", {"command": "resume", "job.type": "eval", "eval.split": "valid"}
    )
    # dump command
    if args.command == "dump":
        dump(args)
        exit()

    # package command
    if args.command == "package":
        package_model(args)
        exit()

    # start command
    if args.command == "start":
        # use toy config file if no config given
        if args.config is None:
            args.config = os.path.join(
                kge_base_dir(), "examples", "toy-complex-train.yaml"
            )
            print(
                "WARNING: No configuration specified; using " + args.config,
                file=sys.stderr,
            )

        if not vars(args)["console.quiet"]:
            print("Loading configuration {}...".format(args.config))
        config.load(args.config)

    # resume command
    if args.command == "resume":
        if os.path.isdir(args.config) and os.path.isfile(args.config + "/config.yaml"):
            args.config += "/config.yaml"
        if not vars(args)["console.quiet"]:
            print("Resuming from configuration {}...".format(args.config))
        config.load(args.config)
        config.folder = os.path.dirname(args.config)
        if not config.folder:
            config.folder = "."
        if not os.path.exists(config.folder):
            raise ValueError(
                "{} is not a valid config file for resuming".format(args.config)
            )

    # overwrite configuration with command line arguments
    for key, value in vars(args).items():
        if key in [
            "command",
            "config",
            "run",
            "folder",
            "checkpoint",
            "abort_when_cache_outdated",
        ]:
            continue
        if value is not None:
            if key == "search.device_pool":
                value = "".join(value).split(",")
            try:
                if isinstance(config.get(key), bool):
                    value = argparse_bool_type(value)
            except KeyError:
                pass
            config.set(key, value)
            if key == "model":
                config._import(value)

    # initialize output folder
    if args.command == "start":
        if args.folder is None:  # means: set default
            config_name = os.path.splitext(os.path.basename(args.config))[0]
            config.folder = os.path.join(
                kge_base_dir(),
                "local",
                "experiments",
                datetime.datetime.now().strftime("%Y%m%d-%H%M%S") + "-" + config_name,
            )
        else:
            config.folder = args.folder

    # catch errors to log them
    try:
        if args.command == "start" and not config.init_folder():
            raise ValueError("output folder {} exists already".format(config.folder))
        config.log("Using folder: {}".format(config.folder))

        # determine checkpoint to resume (if any)
        if hasattr(args, "checkpoint"):
            checkpoint_file = get_checkpoint_file(config, args.checkpoint)

        # disable processing of outdated cached dataset files globally
        Dataset._abort_when_cache_outdated = args.abort_when_cache_outdated

        # log configuration
        config.log("Configuration:")
        config.log(yaml.dump(config.options), prefix="  ")
        config.log("git commit: {}".format(get_git_revision_short_hash()), prefix="  ")

        # set random seeds
        def get_seed(what):
            seed = config.get(f"random_seed.{what}")
            if seed < 0 and config.get("random_seed.default") >= 0:
                import hashlib

                # add an md5 hash of the PRNG name to the default seed so that
                # different PRNGs get different seeds
                seed = (
                    config.get("random_seed.default")
                    + int(hashlib.md5(what.encode()).hexdigest(), 16)
                ) % 0xFFFF  # stay within a 16-bit range

            return seed

        if get_seed("python") > -1:
            import random

            random.seed(get_seed("python"))
        if get_seed("torch") > -1:
            import torch

            torch.manual_seed(get_seed("torch"))
        if get_seed("numpy") > -1:
            import numpy.random

            numpy.random.seed(get_seed("numpy"))
        if get_seed("numba") > -1:
            import numpy as np, numba

            @numba.njit
            def seed_numba(seed):
                np.random.seed(seed)

            seed_numba(get_seed("numba"))

        # let's go
        if args.command == "start" and not args.run:
            config.log("Job created successfully.")
        else:
            # load data
            dataset = Dataset.create(config)

            # let's go
            if args.command == "resume":
                if checkpoint_file is not None:
                    checkpoint = load_checkpoint(
                        checkpoint_file, config.get("job.device")
                    )
                    job = Job.create_from(
                        checkpoint, new_config=config, dataset=dataset
                    )
                else:
                    job = Job.create(config, dataset)
                    job.config.log(
                        "No checkpoint found or specified, starting from scratch..."
                    )
            else:
                job = Job.create(config, dataset)
            job.run()
    except BaseException:
        tb = traceback.format_exc()
        config.log(tb, echo=False)
        raise
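
A standalone restatement of the seed-derivation scheme in get_seed above: hashing the consumer name with md5 and adding it to the default seed gives each PRNG a distinct but reproducible seed.

import hashlib

default_seed = 42  # stands in for random_seed.default

def derive_seed(what: str) -> int:
    # offset the default seed by an md5 hash of the consumer name
    return (default_seed + int(hashlib.md5(what.encode()).hexdigest(), 16)) % 0xFFFF

for name in ("python", "torch", "numpy", "numba"):
    print(name, derive_seed(name))  # four distinct, reproducible seeds
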
Example #5
File: cli.py Project: Allensmile/kge-1
def main():
    # default config
    config = Config()

    # now parse the arguments
    parser = create_parser(config)
    args, unknown_args = parser.parse_known_args()

    # If there were unknown args, add them to the parser and reparse. The
    # correctness of these arguments will be checked later.
    if len(unknown_args) > 0:
        parser = create_parser(
            config, filter(lambda a: a.startswith("--"), unknown_args)
        )
        args = parser.parse_args()

    # process meta-commands
    process_meta_command(args, "create", {"command": "start", "run": False})
    process_meta_command(args, "eval", {"command": "resume", "job.type": "eval"})
    process_meta_command(
        args, "test", {"command": "resume", "job.type": "eval", "eval.split": "test"}
    )
    process_meta_command(
        args, "valid", {"command": "resume", "job.type": "eval", "eval.split": "valid"}
    )
    # dump command
    if args.command == "dump":
        dump(args)
        exit()

    # start command
    if args.command == "start":
        # use toy config file if no config given
        if args.config is None:
            args.config = os.path.join(
                kge_base_dir(), "examples", "toy-complex-train.yaml"
            )
            print("WARNING: No configuration specified; using " + args.config)

        print("Loading configuration {}...".format(args.config))
        config.load(args.config)

    # resume command
    if args.command == "resume":
        if os.path.isdir(args.config) and os.path.isfile(args.config + "/config.yaml"):
            args.config += "/config.yaml"
        print("Resuming from configuration {}...".format(args.config))
        config.load(args.config)
        config.folder = os.path.dirname(args.config)
        if not config.folder:
            config.folder = "."
        if not os.path.exists(config.folder):
            raise ValueError(
                "{} is not a valid config file for resuming".format(args.config)
            )

    # overwrite configuration with command line arguments
    for key, value in vars(args).items():
        if key in [
            "command",
            "config",
            "run",
            "folder",
            "checkpoint",
            "abort_when_cache_outdated",
        ]:
            continue
        if value is not None:
            if key == "search.device_pool":
                value = "".join(value).split(",")
            try:
                if isinstance(config.get(key), bool):
                    value = argparse_bool_type(value)
            except KeyError:
                pass
            config.set(key, value)
            if key == "model":
                config._import(value)

    # initialize output folder
    if args.command == "start":
        if args.folder is None:  # means: set default
            config_name = os.path.splitext(os.path.basename(args.config))[0]
            config.folder = os.path.join(
                kge_base_dir(),
                "local",
                "experiments",
                datetime.datetime.now().strftime("%Y%m%d-%H%M%S") + "-" + config_name,
            )
        else:
            config.folder = args.folder

    # catch errors to log them
    try:
        if args.command == "start" and not config.init_folder():
            raise ValueError("output folder {} exists already".format(config.folder))
        config.log("Using folder: {}".format(config.folder))

        # determine checkpoint to resume (if any)
        if hasattr(args, "checkpoint"):
            if args.checkpoint == "default":
                if config.get("job.type") in ["eval", "valid"]:
                    checkpoint_file = config.checkpoint_file("best")
                else:
                    checkpoint_file = None  # means last
            elif is_number(args.checkpoint, int) or args.checkpoint == "best":
                checkpoint_file = config.checkpoint_file(args.checkpoint)
            else:
                # otherwise, treat it as a filename
                checkpoint_file = args.checkpoint

        # disable processing of outdated cached dataset files globally
        Dataset._abort_when_cache_outdated = args.abort_when_cache_outdated

        # log configuration
        config.log("Configuration:")
        config.log(yaml.dump(config.options), prefix="  ")
        config.log("git commit: {}".format(get_git_revision_short_hash()), prefix="  ")

        # set random seeds
        if config.get("random_seed.python") > -1:
            import random

            random.seed(config.get("random_seed.python"))
        if config.get("random_seed.torch") > -1:
            import torch

            torch.manual_seed(config.get("random_seed.torch"))
        if config.get("random_seed.numpy") > -1:
            import numpy.random

            numpy.random.seed(config.get("random_seed.numpy"))

        # let's go
        if args.command == "start" and not args.run:
            config.log("Job created successfully.")
        else:
            # load data
            dataset = Dataset.load(config)

            # let's go
            job = Job.create(config, dataset)
            if args.command == "resume":
                job.resume(checkpoint_file)
            job.run()
    except BaseException as e:
        tb = traceback.format_exc()
        config.log(tb, echo=False)
        raise e from None
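
The checkpoint argument above resolves in three ways ("default", a number or "best", or a plain filename); a compact restatement of that dispatch, where checkpoint_path is a hypothetical stand-in for config.checkpoint_file:

def resolve_checkpoint(arg: str, job_type: str, checkpoint_path):
    """Return the checkpoint file to resume from, or None for the last one."""
    if arg == "default":
        # eval-style jobs resume from the best checkpoint, others from the last
        return checkpoint_path("best") if job_type in ("eval", "valid") else None
    if arg == "best" or arg.isdigit():
        return checkpoint_path(arg)  # numbered checkpoint or the best one
    return arg  # otherwise treat the argument as a filename

# usage sketch:
# checkpoint_file = resolve_checkpoint(
#     args.checkpoint, config.get("job.type"), config.checkpoint_file)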