Example #1
def _run_train_job(sicnk, device=None):
    """Runs a training job and returns the trace entry of its best validation result.

    Also takes care of appropriate tracing.

    """

    search_job, train_job_index, train_job_config, train_job_count, trace_keys = sicnk

    try:
        # load the job
        if device is not None:
            train_job_config.set("job.device", device)
        search_job.config.log(
            "Starting training job {} ({}/{}) on device {}...".format(
                train_job_config.folder,
                train_job_index + 1,
                train_job_count,
                train_job_config.get("job.device"),
            ))
        checkpoint_file = get_checkpoint_file(train_job_config)
        if checkpoint_file is not None:
            checkpoint = load_checkpoint(checkpoint_file,
                                         train_job_config.get("job.device"))
            job = Job.create_from(
                checkpoint=checkpoint,
                new_config=train_job_config,
                dataset=search_job.dataset,
                parent_job=search_job,
            )
        else:
            job = Job.create(
                config=train_job_config,
                dataset=search_job.dataset,
                parent_job=search_job,
            )

        # process the trace entries so far (in case of a resumed job)
        metric_name = search_job.config.get("valid.metric")
        valid_trace = []

        def copy_to_search_trace(job, trace_entry=None):
            if trace_entry is None:
                trace_entry = job.valid_trace[-1]
            trace_entry = copy.deepcopy(trace_entry)
            for key in trace_keys:
                # Process deprecated options to some extent. Support key renames, but
                # not value renames.
                actual_key = {key: None}
                _process_deprecated_options(actual_key)
                if len(actual_key) > 1:
                    raise KeyError(
                        f"{key} is deprecated but cannot be handled automatically"
                    )
                actual_key = next(iter(actual_key.keys()))
                value = train_job_config.get(actual_key)
                trace_entry[key] = value

            trace_entry["folder"] = os.path.split(train_job_config.folder)[1]
            metric_value = Trace.get_metric(trace_entry, metric_name)
            trace_entry["metric_name"] = metric_name
            trace_entry["metric_value"] = metric_value
            trace_entry["parent_job_id"] = search_job.job_id
            search_job.config.trace(**trace_entry)
            valid_trace.append(trace_entry)

        for trace_entry in job.valid_trace:
            copy_to_search_trace(None, trace_entry)

        # run the job (adding new trace entries as we go)
        # TODO make this less hacky (easier once integrated into SearchJob)
        from kge.job import ManualSearchJob

        if not isinstance(
                search_job,
                ManualSearchJob) or search_job.config.get("manual_search.run"):
            job.post_valid_hooks.append(copy_to_search_trace)
            job.run()
        else:
            search_job.config.log(
                "Skipping running of training job as requested by user.")
            return (train_job_index, None, None)

        # analyze the result
        search_job.config.log("Best result in this training job:")
        best = None
        best_metric = None
        for trace_entry in valid_trace:
            metric = trace_entry["metric_value"]
            if not best or Metric(search_job).better(metric, best_metric):
                best = trace_entry
                best_metric = metric

        # record the best result of this job
        best["child_job_id"] = best["job_id"]
        for k in ["job", "job_id", "type", "parent_job_id", "scope", "event"]:
            if k in best:
                del best[k]
        search_job.trace(
            event="search_completed",
            echo=True,
            echo_prefix="  ",
            log=True,
            scope="train",
            **best,
        )

        # force releasing the GPU memory of the job to avoid memory leakage
        del job
        gc.collect()

        return (train_job_index, best, best_metric)
    except BaseException as e:
        search_job.config.log("Trial {:05d} failed: {}".format(
            train_job_index, repr(e)))
        if search_job.on_error == "continue":
            return (train_job_index, None, None)
        else:
            search_job.config.log(
                "Aborting search due to failure of trial {:05d}".format(
                    train_job_index))
            raise e
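For context, the sicnk tuple unpacked at the top of this function is what the search jobs in Examples #5 and #6 build when they submit it as a task. A minimal sketch of that submission, wrapped in a hypothetical helper (the name submit_all_train_jobs is a placeholder; the real search jobs do this inline):

import concurrent.futures
import kge.job.search

def submit_all_train_jobs(search_job, train_configs, trace_keys):
    """Sketch only: submit one _run_train_job task per expanded train config."""
    for i, train_config in enumerate(train_configs):
        # tuple layout must match the unpacking at the top of _run_train_job:
        # (search_job, train_job_index, train_job_config, train_job_count, trace_keys)
        sicnk = (search_job, i, train_config, len(train_configs), trace_keys)
        search_job.submit_task(kge.job.search._run_train_job, sicnk)
    search_job.wait_task(concurrent.futures.ALL_COMPLETED)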
Example #2
    def _run(self) -> None:
        """Start/resume the training job and run to completion."""

        if self.is_forward_only:
            raise Exception(
                f"{self.__class__.__name__} was initialized for forward only. You can only call run_epoch()"
            )
        if self.epoch == 0:
            self.save(self.config.checkpoint_file(0))

        self.config.log("Starting training...")
        checkpoint_every = self.config.get("train.checkpoint.every")
        checkpoint_keep = self.config.get("train.checkpoint.keep")
        metric_name = self.config.get("valid.metric")
        patience = self.config.get("valid.early_stopping.patience")
        while True:
            # checking for model improvement according to metric_name
            # and do early stopping and keep the best checkpoint
            if (len(self.valid_trace) > 0
                    and self.valid_trace[-1]["epoch"] == self.epoch):
                best_index = Metric(self).best_index(
                    list(
                        map(lambda trace: trace[metric_name],
                            self.valid_trace)))
                if best_index == len(self.valid_trace) - 1:
                    self.save(self.config.checkpoint_file("best"))
                if (patience > 0 and len(self.valid_trace) > patience
                        and best_index < len(self.valid_trace) - patience):
                    self.config.log(
                        "Stopping early ({} did not improve over best result ".
                        format(metric_name) +
                        "in the last {} validation runs).".format(patience))
                    break
                if self.epoch > self.config.get(
                        "valid.early_stopping.threshold.epochs"):
                    achieved = self.valid_trace[best_index][metric_name]
                    target = self.config.get(
                        "valid.early_stopping.threshold.metric_value")
                    if Metric(self).better(target, achieved):
                        self.config.log(
                            "Stopping early ({} did not achieve threshold after {} epochs"
                            .format(metric_name, self.epoch))
                        break

            # should we stop?
            if self.epoch >= self.config.get("train.max_epochs"):
                self.config.log("Maximum number of epochs reached.")
                break

            # update learning rate if warmup is used
            if self.epoch < self._lr_warmup:
                for group in self.optimizer.param_groups:
                    group["lr"] = group["initial_lr"] * (self.epoch +
                                                         1) / self._lr_warmup

            # start a new epoch
            self.epoch += 1
            self.config.log("Starting epoch {}...".format(self.epoch))
            trace_entry = self.run_epoch()
            self.config.log("Finished epoch {}.".format(self.epoch))

            # update model metadata
            self.model.meta["train_job_trace_entry"] = self.trace_entry
            self.model.meta["train_epoch"] = self.epoch
            self.model.meta["train_config"] = self.config
            self.model.meta["train_trace_entry"] = trace_entry

            # validate
            lr_metric = None
            if (self.config.get("valid.every") > 0
                    and self.epoch % self.config.get("valid.every") == 0):
                self.valid_job.epoch = self.epoch
                trace_entry = self.valid_job.run()
                self.valid_trace.append(trace_entry)
                for f in self.post_valid_hooks:
                    f(self)
                self.model.meta["valid_trace_entry"] = trace_entry
                lr_metric = trace_entry[metric_name]

            # update learning rate after warmup
            if self.epoch >= self._lr_warmup:
                # note: lr_metric is None if no validation has been performed in this
                # epoch. This is handled by the optimizers
                self.kge_lr_scheduler.step(lr_metric)

            # create checkpoint and delete old one, if necessary
            self.save(self.config.checkpoint_file(self.epoch))
            if self.epoch > 1:
                delete_checkpoint_epoch = -1
                if checkpoint_every == 0:
                    # do not keep any old checkpoints
                    delete_checkpoint_epoch = self.epoch - 1
                elif (self.epoch - 1) % checkpoint_every != 0:
                    # delete checkpoints that are not in the checkpoint.every schedule
                    delete_checkpoint_epoch = self.epoch - 1
                elif checkpoint_keep > 0:
                    # keep a maximum number of checkpoint_keep checkpoints
                    delete_checkpoint_epoch = (
                        self.epoch - 1 - checkpoint_every * checkpoint_keep)
                if delete_checkpoint_epoch >= 0:
                    if delete_checkpoint_epoch != 0 or not self.config.get(
                            "train.checkpoint.keep_init"):
                        self._delete_checkpoint(delete_checkpoint_epoch)

        self.trace(event="train_completed")
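The patience check above comes down to simple index arithmetic: training stops once the best validation result lies more than patience runs in the past. A standalone sketch of that check, assuming a higher metric value is better (the real code delegates the direction to Metric(self).best_index):

def should_stop_early(metric_history, patience):
    """True if the best value is older than the last `patience` validation runs."""
    if patience <= 0 or len(metric_history) <= patience:
        return False
    best_index = max(range(len(metric_history)), key=metric_history.__getitem__)
    return best_index < len(metric_history) - patience

# the best result (0.35) is 3 validation runs old, so patience=3 stops, patience=4 does not
assert should_stop_early([0.30, 0.35, 0.34, 0.33, 0.32], patience=3)
assert not should_stop_early([0.30, 0.35, 0.34, 0.33, 0.32], patience=4)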
Example #3
    def _run(self) -> None:
        """Start/resume the training job and run to completion."""
        self.config.log("Starting training...")
        checkpoint_every = self.config.get("train.checkpoint.every")
        checkpoint_keep = self.config.get("train.checkpoint.keep")
        metric_name = self.config.get("valid.metric")
        patience = self.config.get("valid.early_stopping.patience")
        while True:
            # checking for model improvement according to metric_name
            # and do early stopping and keep the best checkpoint
            if (
                len(self.valid_trace) > 0
                and self.valid_trace[-1]["epoch"] == self.epoch
            ):
                best_index = Metric(self).best_index(
                    list(map(lambda trace: trace[metric_name], self.valid_trace))
                )
                if best_index == len(self.valid_trace) - 1:
                    self.save(self.config.checkpoint_file("best"))
                if (
                    patience > 0
                    and len(self.valid_trace) > patience
                    and best_index < len(self.valid_trace) - patience
                ):
                    self.config.log(
                        "Stopping early ({} did not improve over best result ".format(
                            metric_name
                        )
                        + "in the last {} validation runs).".format(patience)
                    )
                    break
                if self.epoch > self.config.get(
                    "valid.early_stopping.threshold.epochs"
                ):
                    achieved = self.valid_trace[best_index][metric_name]
                    target = self.config.get(
                        "valid.early_stopping.threshold.metric_value"
                    )
                    if Metric(self).better(target, achieved):
                        self.config.log(
                            "Stopping early ({} did not achieve threshold after {} epochs".format(
                                metric_name, self.epoch
                            )
                        )
                        break

            # should we stop?
            if self.epoch >= self.config.get("train.max_epochs"):
                self.config.log("Maximum number of epochs reached.")
                break

            # start a new epoch
            self.epoch += 1
            self.config.log("Starting epoch {}...".format(self.epoch))
            trace_entry = self.run_epoch()
            self.config.log("Finished epoch {}.".format(self.epoch))

            # update model metadata
            self.model.meta["train_job_trace_entry"] = self.trace_entry
            self.model.meta["train_epoch"] = self.epoch
            self.model.meta["train_config"] = self.config
            self.model.meta["train_trace_entry"] = trace_entry

            # validate and update learning rate
            if (
                self.config.get("valid.every") > 0
                and self.epoch % self.config.get("valid.every") == 0
            ):
                self.valid_job.epoch = self.epoch
                trace_entry = self.valid_job.run()
                self.valid_trace.append(trace_entry)
                for f in self.post_valid_hooks:
                    f(self)
                self.model.meta["valid_trace_entry"] = trace_entry

                # metric-based scheduler step
                self.kge_lr_scheduler.step(trace_entry[metric_name])
            else:
                self.kge_lr_scheduler.step()

            # create checkpoint and delete old one, if necessary
            self.save(self.config.checkpoint_file(self.epoch))
            if self.epoch > 1:
                delete_checkpoint_epoch = -1
                if checkpoint_every == 0:
                    # do not keep any old checkpoints
                    delete_checkpoint_epoch = self.epoch - 1
                elif (self.epoch - 1) % checkpoint_every != 0:
                    # delete checkpoints that are not in the checkpoint.every schedule
                    delete_checkpoint_epoch = self.epoch - 1
                elif checkpoint_keep > 0:
                    # keep a maximum number of checkpoint_keep checkpoints
                    delete_checkpoint_epoch = (
                        self.epoch - 1 - checkpoint_every * checkpoint_keep
                    )
                if delete_checkpoint_epoch > 0:
                    if os.path.exists(
                        self.config.checkpoint_file(delete_checkpoint_epoch)
                    ):
                        self.config.log(
                            "Removing old checkpoint {}...".format(
                                self.config.checkpoint_file(delete_checkpoint_epoch)
                            )
                        )
                        os.remove(self.config.checkpoint_file(delete_checkpoint_epoch))
                    else:
                        self.config.log(
                            "Could not delete old checkpoint {}, does not exits.".format(
                                self.config.checkpoint_file(delete_checkpoint_epoch)
                            )
                        )

        self.trace(event="train_completed")
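The checkpoint bookkeeping at the end of the loop keeps the most recent checkpoint, every checkpoint_every-th epoch, and at most checkpoint_keep of those scheduled checkpoints. A standalone sketch of which old epoch (if any) becomes deletable after a new checkpoint is saved, mirroring the branches above:

def deletable_checkpoint_epoch(epoch, checkpoint_every, checkpoint_keep):
    """Epoch whose checkpoint may be removed after `epoch` was saved, or -1 for none."""
    if epoch <= 1:
        return -1
    if checkpoint_every == 0:
        # do not keep any old checkpoints
        return epoch - 1
    if (epoch - 1) % checkpoint_every != 0:
        # previous epoch is not on the checkpoint_every schedule
        return epoch - 1
    if checkpoint_keep > 0:
        # keep at most checkpoint_keep scheduled checkpoints
        return epoch - 1 - checkpoint_every * checkpoint_keep
    return -1

# with checkpoint_every=5 and checkpoint_keep=3, saving epoch 21 frees epoch 5;
# saving epoch 22 frees the unscheduled epoch 21
assert deletable_checkpoint_epoch(21, 5, 3) == 5
assert deletable_checkpoint_epoch(22, 5, 3) == 21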
Example #4
    def _run(self) -> None:
        """Start/resume the training job and run to completion."""

        if self.is_forward_only:
            raise Exception(
                f"{self.__class__.__name__} was initialized for forward only. You can only call run_epoch()"
            )
        if self.epoch == 0:
            self.save(self.config.checkpoint_file(0))

        self.config.log("Starting training...")
        checkpoint_every = self.config.get("train.checkpoint.every")
        checkpoint_keep = self.config.get("train.checkpoint.keep")
        metric_name = self.config.get("valid.metric")
        patience = self.config.get("valid.early_stopping.patience")
        while True:
            # checking for model improvement according to metric_name
            # and do early stopping and keep the best checkpoint
            if (len(self.valid_trace) > 0
                    and self.valid_trace[-1]["epoch"] == self.epoch):
                best_index = Metric(self).best_index(
                    list(
                        map(lambda trace: trace[metric_name],
                            self.valid_trace)))
                if best_index == len(self.valid_trace) - 1:
                    self.save(self.config.checkpoint_file("best"))
                if (patience > 0 and len(self.valid_trace) > patience
                        and best_index < len(self.valid_trace) - patience):
                    self.config.log(
                        "Stopping early ({} did not improve over best result ".
                        format(metric_name) +
                        "in the last {} validation runs).".format(patience))
                    for f in self.early_stop_hooks:
                        f(self)
                    break
                if self.epoch > self.config.get(
                        "valid.early_stopping.threshold.epochs"):
                    achieved = self.valid_trace[best_index][metric_name]
                    target = self.config.get(
                        "valid.early_stopping.threshold.metric_value")
                    if Metric(self).better(target, achieved):
                        self.config.log(
                            "Stopping early ({} did not achieve threshold after {} epochs"
                            .format(metric_name, self.epoch))
                        for f in self.early_stop_hooks:
                            f(self)
                        break

            # check additional stop conditions
            done = False
            for f in self.early_stop_conditions:
                done = done or f(self)
            if done:
                break

            # should we stop?
            if self.epoch >= self.config.get("train.max_epochs"):
                self.config.log("Maximum number of epochs reached.")
                break

            # update learning rate if warmup is used
            if self.epoch < self._lr_warmup:
                for group in self.optimizer.param_groups:
                    group["lr"] = group["initial_lr"] * (self.epoch +
                                                         1) / self._lr_warmup

            # start a new epoch
            self.epoch += 1
            self.config.log("Starting epoch {}...".format(self.epoch))
            trace_entry = self.run_epoch()
            self.config.log("Finished epoch {}.".format(self.epoch))

            # update model metadata
            self.model.meta["train_job_trace_entry"] = self.trace_entry
            self.model.meta["train_epoch"] = self.epoch
            self.model.meta["train_config"] = self.config
            self.model.meta["train_trace_entry"] = trace_entry

            self.handle_running_checkpoint(checkpoint_every, checkpoint_keep)
            # validate
            lr_metric = None
            if (self.config.get("valid.every") > 0
                    and self.epoch % self.config.get("valid.every") == 0):
                self.handle_validation(metric_name)
            else:
                self.kge_lr_scheduler.step(lr_metric)

        self.trace(event="train_completed")
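This variant additionally consults early_stop_hooks and early_stop_conditions, so callers can plug in custom stopping criteria: each condition is a callable that receives the job and returns True to stop. A minimal sketch of such a condition, using only attributes that appear in the examples above (the condition itself and the name train_job in the usage comment are hypothetical):

import math

def stop_on_nan_metric(job):
    """Hypothetical early-stop condition: stop once the latest validation metric is NaN."""
    if not job.valid_trace:
        return False
    metric_name = job.config.get("valid.metric")
    return math.isnan(job.valid_trace[-1].get(metric_name, 0.0))

# assumed usage on a training job instance:
# train_job.early_stop_conditions.append(stop_on_nan_metric)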
Example #5
    def _run(self):
        # read search configurations and expand them to full configs
        search_configs = copy.deepcopy(self.config.get("manual_search.configurations"))
        all_keys = set()
        for i in range(len(search_configs)):
            search_config = search_configs[i]
            folder = search_config["folder"]
            del search_config["folder"]
            config = self.config.clone(folder)
            config.set("job.type", "train")
            config.options.pop("manual_search", None)  # could be large, don't copy
            flattened_search_config = Config.flatten(search_config)
            config.set_all(flattened_search_config)
            all_keys.update(flattened_search_config.keys())
            search_configs[i] = config

        # create folders for search configs (existing folders remain
        # unmodified)
        for config in search_configs:
            config.init_folder()

        # TODO find a way to create all indexes before running the jobs. The quick hack
        # below does not work because pytorch then throws a "too many open files" error
        # self.dataset.index("train_sp_to_o")
        # self.dataset.index("train_po_to_s")
        # self.dataset.index("valid_sp_to_o")
        # self.dataset.index("valid_po_to_s")
        # self.dataset.index("test_sp_to_o")
        # self.dataset.index("test_po_to_s")

        # now start running/resuming
        for i, config in enumerate(search_configs):
            task_arg = (self, i, config, len(search_configs), all_keys)
            self.submit_task(kge.job.search._run_train_job, task_arg)
        self.wait_task(concurrent.futures.ALL_COMPLETED)

        # if not running the jobs, stop here
        if not self.config.get("manual_search.run"):
            self.config.log("Skipping evaluation of results as requested by user.")
            return

        # collect results
        best_per_job = [None] * len(search_configs)
        best_metric_per_job = [None] * len(search_configs)
        for ibm in self.ready_task_results:
            i, best, best_metric = ibm
            best_per_job[i] = best
            best_metric_per_job[i] = best_metric

        # produce an overall summary
        self.config.log("Result summary:")
        metric_name = self.config.get("valid.metric")
        overall_best = None
        overall_best_metric = None
        for i in range(len(search_configs)):
            best = best_per_job[i]
            best_metric = best_metric_per_job[i]
            if not overall_best or Metric(self).better(
                best_metric, overall_best_metric
            ):
                overall_best = best
                overall_best_metric = best_metric
            self.config.log(
                "{}={:.3f} after {} epochs in folder {}".format(
                    metric_name, best_metric, best["epoch"], best["folder"]
                ),
                prefix="  ",
            )
        self.config.log("And the winner is:")
        self.config.log(
            "{}={:.3f} after {} epochs in folder {}".format(
                metric_name,
                overall_best_metric,
                overall_best["epoch"],
                overall_best["folder"],
            ),
            prefix="  ",
        )
        self.config.log("Best overall result:")
        self.trace(
            event="search_completed",
            echo=True,
            echo_prefix="  ",
            log=True,
            scope="search",
            **overall_best
        )
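Config.flatten is what turns each nested search configuration into the dotted keys that set_all applies. A rough standalone sketch of that kind of flattening (illustrative only, not the actual Config.flatten implementation), assuming plain nested dicts:

def flatten(options, prefix=""):
    """Flatten nested dicts into dotted keys, e.g. {"train": {"lr": 0.1}} -> {"train.lr": 0.1}."""
    flat = {}
    for key, value in options.items():
        full_key = key if not prefix else prefix + "." + key
        if isinstance(value, dict):
            flat.update(flatten(value, full_key))
        else:
            flat[full_key] = value
    return flat

assert flatten({"train": {"optimizer": "Adam", "lr": 0.1}}) == {
    "train.optimizer": "Adam",
    "train.lr": 0.1,
}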
Example #6
    def _run(self):

        # let's go
        trial_no = 0
        while trial_no < self.num_trials:
            gc.collect()
            self.config.log(
                "Registering trial {}/{}...".format(trial_no, self.num_trials - 1)
            )

            # determine next trial
            if trial_no >= len(self.parameters):
                # create a new trial
                parameters, trial_id = self.register_trial()
                if trial_id is None:
                    self.config.log(
                        "Cannot generate trial parameters. Will try again after a "
                        + "running trial has completed."
                    )
                else:
                    # remember the trial
                    self.trial_ids.append(trial_id)
                    self.parameters.append(parameters)
                    self.results.append(None)
                    self.config.log(
                        "Created trial {:05d} with parameters: {}".format(
                            trial_no, parameters
                        )
                    )
            else:
                # use the trial of a resumed run of this job
                parameters, trial_id = self.register_trial(self.parameters[trial_no])
                self.trial_ids.append(trial_id)
                self.config.log(
                    "Resumed trial {:05d} with parameters: {}".format(
                        trial_no, parameters
                    )
                )

            if trial_id is None:
                # couldn't generate a new trial since data is lacking; so wait for data
                self.wait_task()
            elif self.results[trial_no] is not None:
                # trial result is in checkpoint, use it (from prior run of this job)
                self.config.log(
                    "Registering trial {:05d} result: {}".format(
                        trial_no, self.results[trial_no]
                    )
                )
                self.register_trial_result(
                    self.trial_ids[trial_no],
                    self.parameters[trial_no],
                    self.results[trial_no],
                )
            else:  # trial_id is valid, but no result yet
                # create/resume job for trial
                folder = str("{:05d}".format(trial_no))
                config = self.config.clone(folder)
                config.set("job.type", "train")
                config.set_all(_process_deprecated_options(copy.deepcopy(parameters)))
                config.init_folder()

                # save checkpoint here so that trial is not lost
                # TODO make atomic (may corrupt good checkpoint when canceled!)
                self.save(self.config.checkpoint_file(1))

                # run or schedule the trial
                self.submit_task(
                    kge.job.search._run_train_job,
                    (self, trial_no, config, self.num_trials, list(parameters.keys())),
                )

            # on last iteration, wait for all running trials to complete
            if trial_id is not None and trial_no == self.num_trials - 1:
                self.wait_task(return_when=concurrent.futures.ALL_COMPLETED)

            # for each ready trial, store its results
            for ready_trial_no, ready_trial_best, _ in self.ready_task_results:
                if ready_trial_best is not None:
                    self.config.log(
                        "Registering trial {:05d} result: {}".format(
                            ready_trial_no, ready_trial_best["metric_value"]
                        )
                    )
                else:
                    # TODO: currently cannot distinguish failed trials from trials that
                    # haven't been run to completion. Both will have their entry in
                    # self.results set to None
                    self.config.log(
                        "Registering failed trial {:05d}".format(ready_trial_no)
                    )
                self.results[ready_trial_no] = ready_trial_best
                self.register_trial_result(
                    self.trial_ids[ready_trial_no],
                    self.parameters[ready_trial_no],
                    ready_trial_best,
                )

                # save checkpoint
                # TODO make atomic (may corrupt good checkpoint when canceled!)
                self.save(self.config.checkpoint_file(1))

            # clean up
            self.ready_task_results.clear()
            if trial_id is not None:
                # advance to next trial (unless we did not run this one)
                trial_no += 1

        # all done, output failed trials result
        failed_trials = [i for i in range(len(self.results)) if self.results[i] is None]
        self.config.log(
            "{} trials were successful, {} trials failed".format(
                len(self.results) - len(failed_trials), len(failed_trials)
            )
        )
        if len(failed_trials) > 0:
            self.config.log(
                "Failed trials: {}".format(
                    " ".join(["{:05d}".format(x) for x in failed_trials])
                )
            )

        # and best trial
        if len(failed_trials) != len(self.results):
            trial_metric_values = [
                Metric(self).worst() if result is None else result["metric_value"]
                for result in self.results
            ]
            best_trial_index = Metric(self).best_index(trial_metric_values)
            metric_name = self.results[best_trial_index]["metric_name"]
            self.config.log(
                "Best trial ({:05d}): {}={}".format(
                    best_trial_index, metric_name, trial_metric_values[best_trial_index]
                )
            )

            self.trace(
            event="search_completed",
                echo=True,
                echo_prefix="  ",
                log=True,
                scope="search",
                **self.results[best_trial_index]
            )
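The final summary fills failed trials with the worst possible value before picking the best index, so a failed trial can never win. A standalone sketch of that selection, assuming a higher metric value is better and None marks a failed trial (the real code obtains both the worst value and the comparison direction from Metric):

def best_trial(results):
    """Return (index, metric_value) of the best finished trial, ignoring failed ones."""
    worst = float("-inf")  # stand-in for Metric.worst() when higher is better
    values = [worst if result is None else result["metric_value"] for result in results]
    best_index = max(range(len(values)), key=values.__getitem__)
    return best_index, values[best_index]

assert best_trial([{"metric_value": 0.31}, None, {"metric_value": 0.35}]) == (2, 0.35)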