def _run_train_job(sicnk, device=None):
    """Runs a training job and returns the trace entry of its best validation
    result.

    Also takes care of appropriate tracing.

    """
    search_job, train_job_index, train_job_config, train_job_count, trace_keys = sicnk

    try:
        # load the job
        if device is not None:
            train_job_config.set("job.device", device)
        search_job.config.log(
            "Starting training job {} ({}/{}) on device {}...".format(
                train_job_config.folder,
                train_job_index + 1,
                train_job_count,
                train_job_config.get("job.device"),
            )
        )

        # process the trace entries so far (in case of a resumed job)
        metric_name = search_job.config.get("valid.metric")
        valid_trace = []

        def copy_to_search_trace(job, trace_entry=None):
            if trace_entry is None:
                trace_entry = job.valid_trace[-1]
            trace_entry = copy.deepcopy(trace_entry)
            for key in trace_keys:
                # Process deprecated options to some extent: key renames are
                # supported, value renames are not.
                actual_key = {key: None}
                _process_deprecated_options(actual_key)
                if len(actual_key) > 1:
                    raise KeyError(
                        f"{key} is deprecated but cannot be handled automatically"
                    )
                actual_key = next(iter(actual_key.keys()))
                value = train_job_config.get(actual_key)
                trace_entry[key] = value

            trace_entry["folder"] = os.path.split(train_job_config.folder)[1]
            metric_value = Trace.get_metric(trace_entry, metric_name)
            trace_entry["metric_name"] = metric_name
            trace_entry["metric_value"] = metric_value
            trace_entry["parent_job_id"] = search_job.job_id
            search_job.config.trace(**trace_entry)
            valid_trace.append(trace_entry)

        checkpoint_file = get_checkpoint_file(train_job_config)
        if checkpoint_file is not None:
            # resume the training job from its checkpoint and replay its
            # existing validation trace into the search trace
            checkpoint = load_checkpoint(
                checkpoint_file, train_job_config.get("job.device")
            )
            job = Job.create_from(
                checkpoint=checkpoint,
                new_config=train_job_config,
                dataset=search_job.dataset,
                parent_job=search_job,
            )
            for trace_entry in job.valid_trace:
                copy_to_search_trace(None, trace_entry)
        else:
            if train_job_config.get("job.distributed.num_workers") > 0:
                from kge.distributed.funcs import create_and_run_distributed

                valid_trace = create_and_run_distributed(
                    config=train_job_config, dataset=search_job.dataset
                )
            else:
                job = Job.create(
                    config=train_job_config,
                    dataset=search_job.dataset,
                    parent_job=search_job,
                )

        if train_job_config.get("job.distributed.num_workers") <= 0:
            valid_trace = []

            # run the job (adding new trace entries as we go)
            # TODO make this less hacky (easier once integrated into SearchJob)
            from kge.job import ManualSearchJob

            if not isinstance(search_job, ManualSearchJob) or search_job.config.get(
                "manual_search.run"
            ):
                job.post_valid_hooks.append(copy_to_search_trace)
                job.run()
            else:
                search_job.config.log(
                    "Skipping running of training job as requested by user."
                )
                return (train_job_index, None, None)

        # analyze the result
        search_job.config.log("Best result in this training job:")
        best = None
        best_metric = None
        for trace_entry in valid_trace:
            if train_job_config.get("job.distributed.num_workers") <= 0:
                metric = trace_entry["metric_value"]
            else:
                # in the distributed setting, the metric cannot be filled in
                # via a post-valid hook because the local function
                # copy_to_search_trace cannot be pickled
                metric = Trace.get_metric(trace_entry, metric_name)
                trace_entry["metric_value"] = metric
                trace_entry["metric_name"] = metric_name
            if not best or Metric(search_job).better(metric, best_metric):
                best = trace_entry
                best_metric = metric

        # record the best result of this job
        best["child_job_id"] = best["job_id"]
        for k in ["job", "job_id", "type", "parent_job_id", "scope", "event"]:
            if k in best:
                del best[k]
        search_job.trace(
            event="search_completed",
            echo=True,
            echo_prefix="  ",
            log=True,
            scope="train",
            **best,
        )

        # force releasing the GPU memory of the job to avoid memory leakage;
        # in the distributed case, no local `job` object was created
        if "job" in locals():
            del job
        gc.collect()

        return (train_job_index, best, best_metric)
    except BaseException as e:
        search_job.config.log(
            "Trial {:05d} failed: {}".format(train_job_index, repr(e))
        )
        if search_job.on_error == "continue":
            return (train_job_index, None, None)
        else:
            search_job.config.log(
                "Aborting search due to failure of trial {:05d}".format(
                    train_job_index
                )
            )
            raise e
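
# Illustrative sketch (not part of the library's API): a minimal sequential
# driver showing how the "sicnk" tuple consumed by _run_train_job might be
# assembled. The helper name and its `train_job_configs` argument are
# assumptions of this sketch; in practice, the search job itself builds and
# dispatches these tuples, possibly in parallel across devices.
def _run_train_jobs_sequentially(search_job, train_job_configs, trace_keys):
    results = []
    for i, train_job_config in enumerate(train_job_configs):
        sicnk = (search_job, i, train_job_config, len(train_job_configs), trace_keys)
        # device=None keeps whatever "job.device" each config specifies
        results.append(_run_train_job(sicnk, device=None))
    return results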