def _dump_trace(args): """Execute the 'dump trace' command.""" if (args.train or args.valid or args.test or args.truncate or args.job_id or args.checkpoint or args.batch or args.example) and args.search: sys.exit( "--search and any of --train, --valid, --test, --truncate, --job_id," " --checkpoint, --batch, --example are mutually exclusive") entry_type_specified = True if not (args.train or args.valid or args.test or args.search): entry_type_specified = False args.train = True args.valid = True args.test = True truncate_flag = False truncate_epoch = None if isinstance(args.truncate, bool) and args.truncate: truncate_flag = True elif not isinstance(args.truncate, bool): if not args.truncate.isdigit(): sys.exit( "Integer argument or no argument for --truncate must be used") truncate_epoch = int(args.truncate) checkpoint_path = None if ".pt" in os.path.split(args.source)[-1]: checkpoint_path = args.source folder_path = os.path.split(args.source)[0] else: # determine job_id and epoch from last/best checkpoint automatically if args.checkpoint: checkpoint_path = Config.best_or_last_checkpoint_file(args.source) folder_path = args.source if not checkpoint_path and truncate_flag: sys.exit( "--truncate can only be used as a flag when a checkpoint is specified." " Consider specifying a checkpoint or use an integer argument for the" " --truncate option") if checkpoint_path and args.job_id: sys.exit( "--job_id cannot be used together with a checkpoint as the checkpoint" " already specifies the job_id") trace = os.path.join(folder_path, "trace.yaml") if not os.path.isfile(trace): sys.exit( f"No file 'trace.yaml' found at {os.path.abspath(folder_path)}") # process additional keys from --keys and --keysfile keymap = OrderedDict() additional_keys = [] if args.keysfile: with open(args.keysfile, "r") as keyfile: additional_keys = keyfile.readlines() if args.keys: additional_keys += args.keys for line in additional_keys: line = line.rstrip("\n").replace(" ", "") name_key = line.split("=") if len(name_key) == 1: name_key += name_key keymap[name_key[0]] = name_key[1] job_id = None # use job_id and truncate_epoch from checkpoint if checkpoint_path and truncate_flag: checkpoint = torch.load(f=checkpoint_path, map_location="cpu") job_id = checkpoint["job_id"] truncate_epoch = checkpoint["epoch"] # only use job_id from checkpoint elif checkpoint_path: checkpoint = torch.load(f=checkpoint_path, map_location="cpu") job_id = checkpoint["job_id"] # no checkpoint specified job_id might have been set manually elif args.job_id: job_id = args.job_id # don't restrict epoch number in case it has not been specified yet if not truncate_epoch: truncate_epoch = float("inf") entries, job_epochs = [], {} if not args.search: entries, job_epochs = Trace.grep_training_trace_entries( tracefile=trace, train=args.train, test=args.test, valid=args.valid, example=args.example, batch=args.batch, job_id=job_id, epoch_of_last=truncate_epoch, ) if not entries and (args.search or not entry_type_specified): entries = Trace.grep_entries(tracefile=trace, conjunctions=[f"scope: train"]) truncate_epoch = None if entries: args.search = True if not entries and entry_type_specified: sys.exit( "No relevant trace entries found. 
If this was a trace from a search" " job, dont use any of --train --valid --test.") elif not entries: sys.exit("No relevant trace entries found.") if args.list_keys: all_trace_keys = set() if not args.yaml: csv_writer = csv.writer(sys.stdout) # dict[new_name] = (lookup_name, where) # if where=="config"/"trace" it will be looked up automatically # if where=="sep" it must be added in in the write loop separately if args.no_default_keys: default_attributes = OrderedDict() else: default_attributes = OrderedDict([ ("job_id", ("job_id", "sep")), ("dataset", ("dataset.name", "config")), ("model", ("model", "sep")), ("reciprocal", ("reciprocal", "sep")), ("job", ("job", "sep")), ("job_type", ("type", "trace")), ("split", ("split", "sep")), ("epoch", ("epoch", "trace")), ("avg_loss", ("avg_loss", "trace")), ("avg_penalty", ("avg_penalty", "trace")), ("avg_cost", ("avg_cost", "trace")), ("metric_name", ("valid.metric", "config")), ("metric", ("metric", "sep")), ]) if args.search: default_attributes["child_folder"] = ("folder", "trace") default_attributes["child_job_id"] = ("child_job_id", "sep") if not (args.no_header or args.list_keys): csv_writer.writerow( list(default_attributes.keys()) + [key for key in keymap.keys()]) # store configs for job_id's s.t. they need to be loaded only once configs = {} warning_shown = False for entry in entries: current_epoch = entry.get("epoch") job_type = entry.get("job") job_id = entry.get("job_id") if truncate_epoch and not current_epoch <= float(truncate_epoch): continue # filter out entries not relevant to the unique training sequence determined # by the options; not relevant for search if job_type == "train": if current_epoch > job_epochs[job_id]: continue elif job_type == "eval": if "resumed_from_job_id" in entry: if current_epoch > job_epochs[entry.get( "resumed_from_job_id")]: continue elif "parent_job_id" in entry: if current_epoch > job_epochs[entry.get("parent_job_id")]: continue # find relevant config file child_job_id = entry.get( "child_job_id") if "child_job_id" in entry else None config_key = (entry.get("folder") + "/" + str(child_job_id) if args.search else job_id) if config_key in configs.keys(): config = configs[config_key] else: if args.search: if not child_job_id and not warning_shown: # This warning is from Dec 19, 2019. TODO remove print( "Warning: You are dumping the trace of an older search job. 
" "This is fine only if " "the config.yaml files in each subfolder have not been modified " "after running the corresponding training job.", file=sys.stderr, ) warning_shown = True config = get_config_for_job_id( child_job_id, os.path.join(folder_path, entry.get("folder"))) entry["type"] = config.get("train.type") else: config = get_config_for_job_id(job_id, folder_path) configs[config_key] = config if args.list_keys: all_trace_keys.update(entry.keys()) continue new_attributes = OrderedDict() # when training was reciprocal, use the base_model as model if config.get_default("model") == "reciprocal_relations_model": model = config.get_default( "reciprocal_relations_model.base_model.type") # the string that substitutes $base_model in keymap if it exists subs_model = "reciprocal_relations_model.base_model" reciprocal = 1 else: model = config.get_default("model") subs_model = model reciprocal = 0 # search for the additional keys from --keys and --keysfile for new_key in keymap.keys(): lookup = keymap[new_key] # search for special keys value = None if lookup == "$folder": value = os.path.abspath(folder_path) elif lookup == "$checkpoint" and checkpoint_path: value = os.path.abspath(checkpoint_path) elif lookup == "$machine": value = socket.gethostname() if "$base_model" in lookup: lookup = lookup.replace("$base_model", subs_model) # search for ordinary keys; start searching in trace entry then config if not value: value = entry.get(lookup) if not value: try: value = config.get_default(lookup) except: pass # value stays None; creates empty field in csv if value and isinstance(value, bool): value = 1 elif not value and isinstance(value, bool): value = 0 new_attributes[new_key] = value if not args.yaml: # find the actual values for the default attributes actual_default = default_attributes.copy() for new_key in default_attributes.keys(): lookup, where = default_attributes[new_key] if where == "config": actual_default[new_key] = config.get(lookup) elif where == "trace": actual_default[new_key] = entry.get(lookup) # keys with separate treatment # "split" in {train,test,valid} for the datatype # "job" in {train,eval,valid,search} if job_type == "train": if "split" in entry: actual_default["split"] = entry.get("split") else: actual_default["split"] = "train" actual_default["job"] = "train" elif job_type == "eval": if "split" in entry: actual_default["split"] = entry.get( "split") # test or valid else: # deprecated actual_default["split"] = entry.get( "data") # test or valid if entry.get("resumed_from_job_id"): actual_default["job"] = "eval" # from "kge eval" else: actual_default["job"] = "valid" # child of training job else: actual_default["job"] = job_type if "split" in entry: actual_default["split"] = entry.get("split") else: # deprecated actual_default["split"] = entry.get( "data") # test or valid actual_default["job_id"] = job_id.split("-")[0] actual_default["model"] = model actual_default["reciprocal"] = reciprocal # lookup name is in config value is in trace actual_default["metric"] = entry.get( config.get_default("valid.metric")) if args.search: actual_default["child_job_id"] = entry.get( "child_job_id").split("-")[0] for key in list(actual_default.keys()): if key not in default_attributes: del actual_default[key] csv_writer.writerow( [actual_default[new_key] for new_key in actual_default.keys()] + [new_attributes[new_key] for new_key in new_attributes.keys()]) else: entry.update({"reciprocal": reciprocal, "model": model}) if keymap: entry.update(new_attributes) print(entry) if args.list_keys: # 
only one config needed config = configs[list(configs.keys())[0]] options = Config.flatten(config.options) options = sorted(filter(lambda opt: "+++" not in opt, options), key=lambda opt: opt.lower()) if isinstance(args.list_keys, bool): sep = ", " else: sep = args.list_keys print("Default keys for CSV: ") print(*default_attributes.keys(), sep=sep) print("") print("Special keys: ") print(*["$folder", "$checkpoint", "$machine", "$base_model"], sep=sep) print("") print("Keys found in trace: ") print(*sorted(all_trace_keys), sep=sep) print("") print("Keys found in config: ") print(*options, sep=sep)
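
# A minimal sketch (hypothetical helper, not part of the CLI) of the --keys/--keysfile
# line format parsed in _dump_trace above: each line is either "lookup.key" or
# "column_name=lookup.key"; a bare key doubles as its own column name. The example
# key names in the doctest are illustrative only.
def _example_parse_keymap(lines):
    """Parse keymap lines the same way _dump_trace does.

    >>> dict(_example_parse_keymap(["train.optimizer", "dim=$base_model.dim"]))
    {'train.optimizer': 'train.optimizer', 'dim': '$base_model.dim'}
    """
    keymap = OrderedDict()
    for line in lines:
        line = line.rstrip("\n").replace(" ", "")
        name_key = line.split("=")
        if len(name_key) == 1:
            # bare key: use it as both the column name and the lookup key
            name_key += name_key
        keymap[name_key[0]] = name_key[1]
    return keymap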
def _dump_trace(args): """ Executes the 'dump trace' command.""" start = time.time() if (args.train or args.valid or args.test) and args.search: print( "--search and --train, --valid, --test are mutually exclusive", file=sys.stderr, ) exit(1) entry_type_specified = True if not (args.train or args.valid or args.test or args.search): entry_type_specified = False args.train = True args.valid = True args.test = True checkpoint_path = None if ".pt" in os.path.split(args.source)[-1]: checkpoint_path = args.source folder_path = os.path.split(args.source)[0] else: # determine job_id and epoch from last/best checkpoint automatically if args.checkpoint: checkpoint_path = Config.get_best_or_last_checkpoint(args.source) folder_path = args.source if not args.checkpoint and args.truncate: raise ValueError( "You can only use --truncate when a checkpoint is specified." "Consider using --checkpoint or provide a checkpoint file as source" ) trace = os.path.join(folder_path, "trace.yaml") if not os.path.isfile(trace): sys.stderr.write("No trace found at {}\n".format(trace)) exit(1) keymap = OrderedDict() additional_keys = [] if args.keysfile: with open(args.keysfile, "r") as keyfile: additional_keys = keyfile.readlines() if args.keys: additional_keys += args.keys for line in additional_keys: line = line.rstrip("\n").replace(" ", "") name_key = line.split("=") if len(name_key) == 1: name_key += name_key keymap[name_key[0]] = name_key[1] job_id = None epoch = int(args.max_epoch) # use job_id and epoch from checkpoint if checkpoint_path and args.truncate: checkpoint = torch.load(f=checkpoint_path, map_location="cpu") job_id = checkpoint["job_id"] epoch = checkpoint["epoch"] # only use job_id from checkpoint elif checkpoint_path: checkpoint = torch.load(f=checkpoint_path, map_location="cpu") job_id = checkpoint["job_id"] # override job_id and epoch with user arguments if args.job_id: job_id = args.job_id if not epoch: epoch = float("inf") entries, job_epochs = [], {} if not args.search: entries, job_epochs = Trace.grep_training_trace_entries( tracefile=trace, train=args.train, test=args.test, valid=args.valid, example=args.example, batch=args.batch, job_id=job_id, epoch_of_last=epoch, ) if not entries and (args.search or not entry_type_specified): entries = Trace.grep_entries(tracefile=trace, conjunctions=[f"scope: train"]) epoch = None if entries: args.search = True if not entries: print("No relevant trace entries found.", file=sys.stderr) exit(1) middle = time.time() if not args.yaml: csv_writer = csv.writer(sys.stdout) # dict[new_name] = (lookup_name, where) # if where=="config"/"trace" it will be looked up automatically # if where=="sep" it must be added in in the write loop separately if args.no_default_keys: default_attributes = OrderedDict() else: default_attributes = OrderedDict([ ("job_id", ("job_id", "sep")), ("dataset", ("dataset.name", "config")), ("model", ("model", "sep")), ("reciprocal", ("reciprocal", "sep")), ("job", ("job", "sep")), ("job_type", ("type", "trace")), ("split", ("split", "sep")), ("epoch", ("epoch", "trace")), ("avg_loss", ("avg_loss", "trace")), ("avg_penalty", ("avg_penalty", "trace")), ("avg_cost", ("avg_cost", "trace")), ("metric_name", ("valid.metric", "config")), ("metric", ("metric", "sep")), ]) if args.search: default_attributes["child_folder"] = ("folder", "trace") default_attributes["child_job_id"] = ("child_job_id", "sep") if not args.no_header: csv_writer.writerow( list(default_attributes.keys()) + [key for key in keymap.keys()]) # store configs for job_id's s.t. 
they need to be loaded only once configs = {} warning_shown = False for entry in entries: if epoch and not entry.get("epoch") <= float(epoch): continue # filter out not needed entries from a previous job when # a job was resumed from the middle if entry.get("job") == "train": job_id = entry.get("job_id") if entry.get("epoch") > job_epochs[job_id]: continue # find relevant config file child_job_id = entry.get( "child_job_id") if "child_job_id" in entry else None config_key = (entry.get("folder") + "/" + str(child_job_id) if args.search else entry.get("job_id")) if config_key in configs.keys(): config = configs[config_key] else: if args.search: if not child_job_id and not warning_shown: # This warning is from Dec 19, 2019. TODO remove print( "Warning: You are dumping the trace of an older search job. " "This is fine only if " "the config.yaml files in each subfolder have not been modified " "after running the corresponding training job.", file=sys.stderr, ) warning_shown = True config = get_config_for_job_id( child_job_id, os.path.join(folder_path, entry.get("folder"))) entry["type"] = config.get("train.type") else: config = get_config_for_job_id(entry.get("job_id"), folder_path) configs[config_key] = config new_attributes = OrderedDict() if config.get_default("model") == "reciprocal_relations_model": model = config.get_default( "reciprocal_relations_model.base_model.type") # the string that substitutes $base_model in keymap if it exists subs_model = "reciprocal_relations_model.base_model" reciprocal = 1 else: model = config.get_default("model") subs_model = model reciprocal = 0 for new_key in keymap.keys(): lookup = keymap[new_key] if "$base_model" in lookup: lookup = lookup.replace("$base_model", subs_model) try: if lookup == "$folder": val = os.path.abspath(folder_path) elif lookup == "$checkpoint": val = os.path.abspath(checkpoint_path) elif lookup == "$machine": val = socket.gethostname() else: val = config.get_default(lookup) except: # creates empty field if key is not existing val = entry.get(lookup) if type(val) == bool and val: val = 1 elif type(val) == bool and not val: val = 0 new_attributes[new_key] = val if not args.yaml: # find the actual values for the default attributes actual_default = default_attributes.copy() for new_key in default_attributes.keys(): lookup, where = default_attributes[new_key] if where == "config": actual_default[new_key] = config.get(lookup) elif where == "trace": actual_default[new_key] = entry.get(lookup) # keys with separate treatment # "split" in {train,test,valid} for the datatype # "job" in {train,eval,valid,search} if entry.get("job") == "train": actual_default["split"] = "train" actual_default["job"] = "train" elif entry.get("job") == "eval": actual_default["split"] = entry.get("data") # test or valid if entry.get("resumed_from_job_id"): actual_default["job"] = "eval" # from "kge eval" else: actual_default["job"] = "valid" # child of training job else: actual_default["job"] = entry.get("job") actual_default["split"] = entry.get("data") actual_default["job_id"] = entry.get("job_id").split("-")[0] actual_default["model"] = model actual_default["reciprocal"] = reciprocal # lookup name is in config value is in trace actual_default["metric"] = entry.get( config.get_default("valid.metric")) if args.search: actual_default["child_job_id"] = entry.get( "child_job_id").split("-")[0] for key in list(actual_default.keys()): if key not in default_attributes: del actual_default[key] csv_writer.writerow( [actual_default[new_key] for new_key in 
actual_default.keys()] + [new_attributes[new_key] for new_key in new_attributes.keys()]) else: entry.update({"reciprocal": reciprocal, "model": model}) if keymap: entry.update(new_attributes) sys.stdout.write(re.sub("[{}']", "", str(entry))) sys.stdout.write("\n") end = time.time() if args.timeit: sys.stdout.write("Grep + processing took {} \n".format(middle - start)) sys.stdout.write("Writing took {}".format(end - middle))
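
# A minimal sketch (hypothetical helper, not used by _dump_trace) of the
# (lookup_name, where) dispatch behind default_attributes: "config" keys are read
# from the job's config, "trace" keys from the trace entry, and "sep" keys are
# filled in separately in the write loop (e.g. job_id.split("-")[0]).
def _example_resolve_default_attribute(spec, config_values, trace_entry):
    lookup, where = spec
    if where == "config":
        return config_values.get(lookup)
    if where == "trace":
        return trace_entry.get(lookup)
    return None  # "sep": computed by dedicated code in the write loop


# Example: _example_resolve_default_attribute(("dataset.name", "config"),
# {"dataset.name": "fb15k-237"}, {}) returns "fb15k-237".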
def _run_train_job(sicnk, device=None):
    """Runs a training job and returns the trace entry of its best validation result.

    Also takes care of appropriate tracing.
    """

    search_job, train_job_index, train_job_config, train_job_count, trace_keys = sicnk

    try:
        # load the job
        if device is not None:
            train_job_config.set("job.device", device)
        search_job.config.log(
            "Starting training job {} ({}/{}) on device {}...".format(
                train_job_config.folder,
                train_job_index + 1,
                train_job_count,
                train_job_config.get("job.device"),
            )
        )

        # process the trace entries so far (in case of a resumed job)
        metric_name = search_job.config.get("valid.metric")
        valid_trace = []

        def copy_to_search_trace(job, trace_entry=None):
            if trace_entry is None:
                trace_entry = job.valid_trace[-1]
            trace_entry = copy.deepcopy(trace_entry)
            for key in trace_keys:
                # Process deprecated options to some extent. Support key renames,
                # but not value renames.
                actual_key = {key: None}
                _process_deprecated_options(actual_key)
                if len(actual_key) > 1:
                    raise KeyError(
                        f"{key} is deprecated but cannot be handled automatically"
                    )
                actual_key = next(iter(actual_key.keys()))
                value = train_job_config.get(actual_key)
                trace_entry[key] = value

            trace_entry["folder"] = os.path.split(train_job_config.folder)[1]
            metric_value = Trace.get_metric(trace_entry, metric_name)
            trace_entry["metric_name"] = metric_name
            trace_entry["metric_value"] = metric_value
            trace_entry["parent_job_id"] = search_job.job_id
            search_job.config.trace(**trace_entry)
            valid_trace.append(trace_entry)

        checkpoint_file = get_checkpoint_file(train_job_config)
        if checkpoint_file is not None:
            checkpoint = load_checkpoint(
                checkpoint_file, train_job_config.get("job.device")
            )
            job = Job.create_from(
                checkpoint=checkpoint,
                new_config=train_job_config,
                dataset=search_job.dataset,
                parent_job=search_job,
            )
            for trace_entry in job.valid_trace:
                copy_to_search_trace(None, trace_entry)
        else:
            if train_job_config.get("job.distributed.num_workers") > 0:
                from kge.distributed.funcs import create_and_run_distributed

                valid_trace = create_and_run_distributed(
                    config=train_job_config, dataset=search_job.dataset
                )
            else:
                job = Job.create(
                    config=train_job_config,
                    dataset=search_job.dataset,
                    parent_job=search_job,
                )

        if train_job_config.get("job.distributed.num_workers") <= 0:
            valid_trace = []
            # run the job (adding new trace entries as we go)
            # TODO make this less hacky (easier once integrated into SearchJob)
            from kge.job import ManualSearchJob

            if not isinstance(search_job, ManualSearchJob) or search_job.config.get(
                "manual_search.run"
            ):
                job.post_valid_hooks.append(copy_to_search_trace)
                job.run()
            else:
                search_job.config.log(
                    "Skipping running of training job as requested by user."
                )
                return (train_job_index, None, None)

        # analyze the result
        search_job.config.log("Best result in this training job:")
        best = None
        best_metric = None
        for trace_entry in valid_trace:
            if train_job_config.get("job.distributed.num_workers") <= 0:
                metric = trace_entry["metric_value"]
            else:
                # we cannot do this via the post-valid hook in the distributed
                # setting: the function copy_to_search_trace cannot be pickled
                metric = Trace.get_metric(trace_entry, metric_name)
                trace_entry["metric_value"] = metric
                trace_entry["metric_name"] = metric_name
            if not best or Metric(search_job).better(metric, best_metric):
                best = trace_entry
                best_metric = metric

        # record the best result of this job
        best["child_job_id"] = best["job_id"]
        for k in ["job", "job_id", "type", "parent_job_id", "scope", "event"]:
            if k in best:
                del best[k]
        search_job.trace(
            event="search_completed",
            echo=True,
            echo_prefix=" ",
            log=True,
            scope="train",
            **best,
        )

        # force releasing the GPU memory of the job to avoid memory leakage
        del job
        gc.collect()

        return (train_job_index, best, best_metric)
    except BaseException as e:
        search_job.config.log(
            "Trial {:05d} failed: {}".format(train_job_index, repr(e))
        )
        if search_job.on_error == "continue":
            return (train_job_index, None, None)
        else:
            search_job.config.log(
                "Aborting search due to failure of trial {:05d}".format(
                    train_job_index
                )
            )
            raise e
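
# A minimal sketch (hypothetical, not part of the search job) of how the
# (train_job_index, best, best_metric) tuples returned by _run_train_job can be
# aggregated across trials. Real search jobs compare metrics with
# Metric(search_job).better(), which also handles metrics where lower is better;
# the higher_is_better flag here is a simplification.
def _example_pick_best_trial(results, higher_is_better=True):
    """Return the (index, trace_entry) of the best successful trial."""
    best_index, best_entry, best_metric = None, None, None
    for index, entry, metric in results:
        if entry is None:
            continue  # failed or skipped trial (returned as (index, None, None))
        if (
            best_metric is None
            or (higher_is_better and metric > best_metric)
            or (not higher_is_better and metric < best_metric)
        ):
            best_index, best_entry, best_metric = index, entry, metric
    return best_index, best_entry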