def run(self):
    # read grid search options range
    all_keys = []
    all_keys_short = []
    all_values = []
    all_indexes = []
    grid_configs = self.config.get("grid_search.parameters")
    for k, v in sorted(Config.flatten(grid_configs).items()):
        all_keys.append(k)
        short_key = k[k.rfind(".") + 1:]
        if "_" in short_key:
            # just keep first letter after each _
            all_keys_short.append(
                "".join(map(lambda s: s[0], short_key.split("_")))
            )
        else:
            # keep up to three letters
            all_keys_short.append(short_key[:3])
        all_values.append(v)
        all_indexes.append(range(len(v)))

    # create search configs
    search_configs = []
    for indexes in itertools.product(*all_indexes):
        # obtain values for changed parameters
        values = list(
            map(lambda ik: all_values[ik[0]][ik[1]], enumerate(list(indexes)))
        )

        # create search configuration and check whether correct
        dummy_config = self.config.clone()
        search_config = Config(load_default=False)
        search_config.options["folder"] = "_".join(
            map(lambda i: all_keys_short[i] + str(values[i]), range(len(values)))
        )
        for i, key in enumerate(all_keys):
            dummy_config.set(key, values[i])  # to test whether correct k/v pair
            search_config.set(key, values[i], create=True)  # and remember it

        search_configs.append(search_config.options)

    # create configuration file of search job
    self.config.set("search.type", "manual")
    self.config.set("manual_search.configurations", search_configs)
    self.config.save(os.path.join(self.config.folder, "config.yaml"))

    # and run it
    if self.config.get("grid_search.run"):
        job = Job.create(self.config, self.dataset, parent_job=self)
        job.resume()
        job.run()
    else:
        self.config.log("Skipping running of search job as requested by user...")
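
# --- Illustrative sketch (not part of the library code above) ---
# Shows how the grid expansion in run() behaves for a hypothetical
# `grid_search.parameters` block; the keys and values below are made up
# for illustration, and only the key-shortening/product logic mirrors run().
import itertools

grid = {
    "lookup_embedder.dim": [100, 200],
    "train.batch_size": [128, 256],
}
all_keys = sorted(grid.keys())

def shorten(key):
    # same rule as run(): first letter after each "_" in the last key
    # component, otherwise its first three letters
    short = key[key.rfind(".") + 1:]
    if "_" in short:
        return "".join(s[0] for s in short.split("_"))
    return short[:3]

for combo in itertools.product(*(grid[k] for k in all_keys)):
    folder = "_".join(shorten(k) + str(v) for k, v in zip(all_keys, combo))
    print(folder)
# prints: dim100_bs128, dim100_bs256, dim200_bs128, dim200_bs256
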
def _run_train_job(sicnk, device=None):
    """Runs a training job and returns the trace entry of its best validation result.

    Also takes care of appropriate tracing.
    """
    search_job, train_job_index, train_job_config, train_job_count, trace_keys = sicnk

    try:
        # load the job
        if device is not None:
            train_job_config.set("job.device", device)
        search_job.config.log(
            "Starting training job {} ({}/{}) on device {}...".format(
                train_job_config.folder,
                train_job_index + 1,
                train_job_count,
                train_job_config.get("job.device"),
            )
        )
        checkpoint_file = get_checkpoint_file(train_job_config)
        if checkpoint_file is not None:
            checkpoint = load_checkpoint(
                checkpoint_file, train_job_config.get("job.device")
            )
            job = Job.create_from(
                checkpoint=checkpoint,
                new_config=train_job_config,
                dataset=search_job.dataset,
                parent_job=search_job,
            )
        else:
            job = Job.create(
                config=train_job_config,
                dataset=search_job.dataset,
                parent_job=search_job,
            )

        # process the trace entries so far (in case of a resumed job)
        metric_name = search_job.config.get("valid.metric")
        valid_trace = []

        def copy_to_search_trace(job, trace_entry=None):
            if trace_entry is None:
                trace_entry = job.valid_trace[-1]
            trace_entry = copy.deepcopy(trace_entry)
            for key in trace_keys:
                # Process deprecated options to some extent. Support key renames, but
                # not value renames.
                actual_key = {key: None}
                _process_deprecated_options(actual_key)
                if len(actual_key) > 1:
                    raise KeyError(
                        f"{key} is deprecated but cannot be handled automatically"
                    )
                actual_key = next(iter(actual_key.keys()))
                value = train_job_config.get(actual_key)
                trace_entry[key] = value

            trace_entry["folder"] = os.path.split(train_job_config.folder)[1]
            metric_value = Trace.get_metric(trace_entry, metric_name)
            trace_entry["metric_name"] = metric_name
            trace_entry["metric_value"] = metric_value
            trace_entry["parent_job_id"] = search_job.job_id
            search_job.config.trace(**trace_entry)
            valid_trace.append(trace_entry)

        for trace_entry in job.valid_trace:
            copy_to_search_trace(None, trace_entry)

        # run the job (adding new trace entries as we go)
        # TODO make this less hacky (easier once integrated into SearchJob)
        from kge.job import ManualSearchJob

        if not isinstance(
            search_job, ManualSearchJob
        ) or search_job.config.get("manual_search.run"):
            job.post_valid_hooks.append(copy_to_search_trace)
            job.run()
        else:
            search_job.config.log(
                "Skipping running of training job as requested by user."
            )
            return (train_job_index, None, None)

        # analyze the result
        search_job.config.log("Best result in this training job:")
        best = None
        best_metric = None
        for trace_entry in valid_trace:
            metric = trace_entry["metric_value"]
            if not best or Metric(search_job).better(metric, best_metric):
                best = trace_entry
                best_metric = metric

        # record the best result of this job
        best["child_job_id"] = best["job_id"]
        for k in ["job", "job_id", "type", "parent_job_id", "scope", "event"]:
            if k in best:
                del best[k]
        search_job.trace(
            event="search_completed",
            echo=True,
            echo_prefix=" ",
            log=True,
            scope="train",
            **best,
        )

        # force releasing the GPU memory of the job to avoid memory leakage
        del job
        gc.collect()

        return (train_job_index, best, best_metric)
    except BaseException as e:
        search_job.config.log(
            "Trial {:05d} failed: {}".format(train_job_index, repr(e))
        )
        if search_job.on_error == "continue":
            return (train_job_index, None, None)
        else:
            search_job.config.log(
                "Aborting search due to failure of trial {:05d}".format(
                    train_job_index
                )
            )
            raise e
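
# --- Illustrative sketch (not part of the library code above) ---
# The "analyze the result" loop in _run_train_job keeps the trace entry with
# the best validation metric. A plain-Python equivalent over hypothetical
# trace entries, assuming higher is better (in the library that decision is
# delegated to Metric(search_job).better):
valid_trace_example = [
    {"epoch": 1, "metric_value": 0.21},
    {"epoch": 2, "metric_value": 0.27},
    {"epoch": 3, "metric_value": 0.25},
]
best_entry = max(valid_trace_example, key=lambda entry: entry["metric_value"])
print(best_entry)  # {'epoch': 2, 'metric_value': 0.27}
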
def run(self):
    torch_device = self.config.get("job.device")
    if self.config.get("job.device") == "cuda":
        torch_device = "cuda:0"
    if torch_device != "cpu":
        torch.cuda.set_device(torch_device)
    # seeds need to be set in every process
    set_seeds(self.config, self.rank)

    os.environ["MASTER_ADDR"] = self.config.get("job.distributed.master_ip")
    os.environ["MASTER_PORT"] = self.config.get("job.distributed.master_port")
    min_rank = get_min_rank(self.config)
    print("before init", self.rank + min_rank)
    dist.init_process_group(
        backend="gloo",
        init_method="env://",
        world_size=self.num_total_workers + min_rank,
        rank=self.rank + min_rank,
        timeout=datetime.timedelta(hours=6),
    )
    worker_ranks = list(range(min_rank, self.num_total_workers + min_rank))
    worker_group = dist.new_group(worker_ranks, timeout=datetime.timedelta(hours=6))

    # create parameter server
    server = None
    if self.config.get("job.distributed.parameter_server") == "lapse":
        os.environ["DMLC_NUM_WORKER"] = "0"
        os.environ["DMLC_NUM_SERVER"] = str(self.num_total_workers)
        os.environ["DMLC_ROLE"] = "server"
        os.environ["DMLC_PS_ROOT_URI"] = self.config.get(
            "job.distributed.master_ip"
        )
        os.environ["DMLC_PS_ROOT_PORT"] = self.config.get(
            "job.distributed.lapse_port"
        )
        num_workers_per_server = 1
        lapse.setup(self.num_keys, num_workers_per_server)
        server = lapse.Server(self.num_keys, self.embedding_dim + self.optimizer_dim)
    elif self.config.get("job.distributed.parameter_server") == "shared":
        server = self.parameters

    # create train-worker config, dataset and folder
    device_pool: list = self.config.get("job.device_pool")
    if len(device_pool) == 0:
        device_pool.append(self.config.get("job.device"))
    worker_id = self.rank
    config = deepcopy(self.config)
    config.set("job.device", device_pool[worker_id % len(device_pool)])
    config.folder = os.path.join(self.config.folder, f"worker-{self.rank}")
    config.init_folder()
    dataset = deepcopy(self.dataset)

    parameter_client = KgeParameterClient.create(
        client_type=self.config.get("job.distributed.parameter_server"),
        server_id=0,
        client_id=worker_id + min_rank,
        embedding_dim=self.embedding_dim + self.optimizer_dim,
        server=server,
        num_keys=self.num_keys,
        num_meta_keys=self.num_meta_keys,
        worker_group=worker_group,
    )
    # don't re-initialize the model after loading checkpoint
    init_for_load_only = self.checkpoint_name is not None
    job = Job.create(
        config=config,
        dataset=dataset,
        parameter_client=parameter_client,
        init_for_load_only=init_for_load_only,
    )
    if self.checkpoint_name is not None:
        checkpoint = load_checkpoint(self.checkpoint_name)
        job._load(checkpoint)
        job.load_distributed(checkpoint_name=self.checkpoint_name)

    job.run()

    # all done, clean up
    print("shut down everything")
    parameter_client.barrier()
    if hasattr(job, "work_scheduler_client"):
        job.work_scheduler_client.shutdown()
    parameter_client.shutdown()
    # delete all occurrences of the parameter client to properly shut down lapse
    # del job
    del job.parameter_client
    del job.model.get_s_embedder().parameter_client
    del job.model.get_p_embedder().parameter_client
    del job.model
    if hasattr(job, "optimizer"):
        del job.optimizer
    del parameter_client
    gc.collect()  # make sure the lapse-worker destructor is called

    # shutdown server
    if server is not None and type(server) != torch.Tensor:
        server.shutdown()

    if self.result_pipe is not None:
        if hasattr(job, "valid_trace"):
            # if we validate from a checkpoint there is no valid trace
            self.result_pipe.send(job.valid_trace)
        else:
            self.result_pipe.send(None)
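
# --- Illustrative sketch (not part of the library code above) ---
# Rank arithmetic assumed by the process-group setup in run(): the first
# `min_rank` ranks are reserved for non-training processes (e.g. a work
# scheduler or parameter server; an assumption made here for illustration)
# and the training workers occupy the ranks after them. The numbers below
# are hypothetical.
num_total_workers = 4
min_rank = 2

world_size = num_total_workers + min_rank                     # 6 processes overall
worker_ranks = list(range(min_rank, num_total_workers + min_rank))
print(world_size, worker_ranks)                               # 6 [2, 3, 4, 5]
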
def main():
    # default config
    config = Config()

    # now parse the arguments
    parser = create_parser(config)
    args, unknown_args = parser.parse_known_args()

    # If there were unknown args, add them to the parser and reparse. The correctness
    # of these arguments will be checked later.
    if len(unknown_args) > 0:
        parser = create_parser(
            config, filter(lambda a: a.startswith("--"), unknown_args)
        )
        args = parser.parse_args()

    # process meta-commands
    process_meta_command(args, "create", {"command": "start", "run": False})
    process_meta_command(args, "eval", {"command": "resume", "job.type": "eval"})
    process_meta_command(
        args, "test", {"command": "resume", "job.type": "eval", "eval.split": "test"}
    )
    process_meta_command(
        args, "valid", {"command": "resume", "job.type": "eval", "eval.split": "valid"}
    )

    # dump command
    if args.command == "dump":
        dump(args)
        exit()

    # package command
    if args.command == "package":
        package_model(args)
        exit()

    # start command
    if args.command == "start":
        # use toy config file if no config given
        if args.config is None:
            args.config = kge_base_dir() + "/" + "examples/toy-complex-train.yaml"
            print(
                "WARNING: No configuration specified; using " + args.config,
                file=sys.stderr,
            )
        if not vars(args)["console.quiet"]:
            print("Loading configuration {}...".format(args.config))
        config.load(args.config)

    # resume command
    if args.command == "resume":
        if os.path.isdir(args.config) and os.path.isfile(args.config + "/config.yaml"):
            args.config += "/config.yaml"
        if not vars(args)["console.quiet"]:
            print("Resuming from configuration {}...".format(args.config))
        config.load(args.config)
        config.folder = os.path.dirname(args.config)
        if not config.folder:
            config.folder = "."
        if not os.path.exists(config.folder):
            raise ValueError(
                "{} is not a valid config file for resuming".format(args.config)
            )

    # overwrite configuration with command line arguments
    for key, value in vars(args).items():
        if key in [
            "command",
            "config",
            "run",
            "folder",
            "checkpoint",
            "abort_when_cache_outdated",
        ]:
            continue
        if value is not None:
            if key == "search.device_pool":
                value = "".join(value).split(",")
            try:
                if isinstance(config.get(key), bool):
                    value = argparse_bool_type(value)
            except KeyError:
                pass
            config.set(key, value)
            if key == "model":
                config._import(value)

    # initialize output folder
    if args.command == "start":
        if args.folder is None:  # means: set default
            config_name = os.path.splitext(os.path.basename(args.config))[0]
            config.folder = os.path.join(
                kge_base_dir(),
                "local",
                "experiments",
                datetime.datetime.now().strftime("%Y%m%d-%H%M%S") + "-" + config_name,
            )
        else:
            config.folder = args.folder

    # catch errors to log them
    try:
        if args.command == "start" and not config.init_folder():
            raise ValueError("output folder {} exists already".format(config.folder))
        config.log("Using folder: {}".format(config.folder))

        # determine checkpoint to resume (if any)
        if hasattr(args, "checkpoint"):
            checkpoint_file = get_checkpoint_file(config, args.checkpoint)

        # disable processing of outdated cached dataset files globally
        Dataset._abort_when_cache_outdated = args.abort_when_cache_outdated

        # log configuration
        config.log("Configuration:")
        config.log(yaml.dump(config.options), prefix=" ")
        config.log("git commit: {}".format(get_git_revision_short_hash()), prefix=" ")

        # set random seeds
        def get_seed(what):
            seed = config.get(f"random_seed.{what}")
            if seed < 0 and config.get(f"random_seed.default") >= 0:
                import hashlib

                # we add an md5 hash to the default seed so that different PRNGs get a
                # different seed
                seed = (
                    config.get(f"random_seed.default")
                    + int(hashlib.md5(what.encode()).hexdigest(), 16)
                ) % 0xFFFF  # stay 32-bit

            return seed

        if get_seed("python") > -1:
            import random

            random.seed(get_seed("python"))
        if get_seed("torch") > -1:
            import torch

            torch.manual_seed(get_seed("torch"))
        if get_seed("numpy") > -1:
            import numpy.random

            numpy.random.seed(get_seed("numpy"))
        if get_seed("numba") > -1:
            import numpy as np, numba

            @numba.njit
            def seed_numba(seed):
                np.random.seed(seed)

            seed_numba(get_seed("numba"))

        # let's go
        if args.command == "start" and not args.run:
            config.log("Job created successfully.")
        else:
            # load data
            dataset = Dataset.create(config)

            # let's go
            if args.command == "resume":
                if checkpoint_file is not None:
                    checkpoint = load_checkpoint(
                        checkpoint_file, config.get("job.device")
                    )
                    job = Job.create_from(
                        checkpoint, new_config=config, dataset=dataset
                    )
                else:
                    job = Job.create(config, dataset)
                    job.config.log(
                        "No checkpoint found or specified, starting from scratch..."
                    )
            else:
                job = Job.create(config, dataset)
            job.run()
    except BaseException:
        tb = traceback.format_exc()
        config.log(tb, echo=False)
        raise
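
# --- Illustrative sketch (not part of the library code above) ---
# Shows the arithmetic used by get_seed() in main(): each PRNG name is hashed
# with md5 and added to the shared default seed, so every generator gets a
# distinct but reproducible seed. The default seed value below is made up.
import hashlib

default_seed = 42
for what in ["python", "torch", "numpy", "numba"]:
    derived = (default_seed + int(hashlib.md5(what.encode()).hexdigest(), 16)) % 0xFFFF
    print(what, derived)
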
def main():
    # default config
    config = Config()

    # now parse the arguments
    parser = create_parser(config)
    args, unknown_args = parser.parse_known_args()

    # If there were unknown args, add them to the parser and reparse. The correctness
    # of these arguments will be checked later.
    if len(unknown_args) > 0:
        parser = create_parser(
            config, filter(lambda a: a.startswith("--"), unknown_args)
        )
        args = parser.parse_args()

    # process meta-commands
    process_meta_command(args, "create", {"command": "start", "run": False})
    process_meta_command(args, "eval", {"command": "resume", "job.type": "eval"})
    process_meta_command(
        args, "test", {"command": "resume", "job.type": "eval", "eval.split": "test"}
    )
    process_meta_command(
        args, "valid", {"command": "resume", "job.type": "eval", "eval.split": "valid"}
    )

    # dump command
    if args.command == "dump":
        dump(args)
        exit()

    # start command
    if args.command == "start":
        # use toy config file if no config given
        if args.config is None:
            args.config = kge_base_dir() + "/" + "examples/toy-complex-train.yaml"
            print("WARNING: No configuration specified; using " + args.config)

        print("Loading configuration {}...".format(args.config))
        config.load(args.config)

    # resume command
    if args.command == "resume":
        if os.path.isdir(args.config) and os.path.isfile(args.config + "/config.yaml"):
            args.config += "/config.yaml"
        print("Resuming from configuration {}...".format(args.config))
        config.load(args.config)
        config.folder = os.path.dirname(args.config)
        if not config.folder:
            config.folder = "."
        if not os.path.exists(config.folder):
            raise ValueError(
                "{} is not a valid config file for resuming".format(args.config)
            )

    # overwrite configuration with command line arguments
    for key, value in vars(args).items():
        if key in [
            "command",
            "config",
            "run",
            "folder",
            "checkpoint",
            "abort_when_cache_outdated",
        ]:
            continue
        if value is not None:
            if key == "search.device_pool":
                value = "".join(value).split(",")
            try:
                if isinstance(config.get(key), bool):
                    value = argparse_bool_type(value)
            except KeyError:
                pass
            config.set(key, value)
            if key == "model":
                config._import(value)

    # initialize output folder
    if args.command == "start":
        if args.folder is None:  # means: set default
            config_name = os.path.splitext(os.path.basename(args.config))[0]
            config.folder = os.path.join(
                kge_base_dir(),
                "local",
                "experiments",
                datetime.datetime.now().strftime("%Y%m%d-%H%M%S") + "-" + config_name,
            )
        else:
            config.folder = args.folder

    # catch errors to log them
    try:
        if args.command == "start" and not config.init_folder():
            raise ValueError("output folder {} exists already".format(config.folder))
        config.log("Using folder: {}".format(config.folder))

        # determine checkpoint to resume (if any)
        if hasattr(args, "checkpoint"):
            if args.checkpoint == "default":
                if config.get("job.type") in ["eval", "valid"]:
                    checkpoint_file = config.checkpoint_file("best")
                else:
                    checkpoint_file = None  # means last
            elif is_number(args.checkpoint, int) or args.checkpoint == "best":
                checkpoint_file = config.checkpoint_file(args.checkpoint)
            else:
                # otherwise, treat it as a filename
                checkpoint_file = args.checkpoint

        # disable processing of outdated cached dataset files globally
        Dataset._abort_when_cache_outdated = args.abort_when_cache_outdated

        # log configuration
        config.log("Configuration:")
        config.log(yaml.dump(config.options), prefix=" ")
        config.log("git commit: {}".format(get_git_revision_short_hash()), prefix=" ")

        # set random seeds
        if config.get("random_seed.python") > -1:
            import random

            random.seed(config.get("random_seed.python"))
        if config.get("random_seed.torch") > -1:
            import torch

            torch.manual_seed(config.get("random_seed.torch"))
        if config.get("random_seed.numpy") > -1:
            import numpy.random

            numpy.random.seed(config.get("random_seed.numpy"))

        # let's go
        if args.command == "start" and not args.run:
            config.log("Job created successfully.")
        else:
            # load data
            dataset = Dataset.load(config)

            # let's go
            job = Job.create(config, dataset)
            if args.command == "resume":
                job.resume(checkpoint_file)
            job.run()
    except BaseException as e:
        tb = traceback.format_exc()
        config.log(tb, echo=False)
        raise e from None
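
# --- Illustrative sketch (not part of the library code above) ---
# Mirrors the checkpoint-resolution branches in main() with plain strings.
# `resolve`, the file-name pattern, and the example inputs are hypothetical;
# in the real code, config.checkpoint_file() maps "best" or an epoch number
# to the actual checkpoint path and is_number() performs the numeric check.
def resolve(checkpoint_arg, job_type):
    if checkpoint_arg == "default":
        if job_type in ["eval", "valid"]:
            return "checkpoint_best.pt"
        return None  # None means: use the last checkpoint
    if checkpoint_arg == "best" or checkpoint_arg.isdigit():
        return f"checkpoint_{checkpoint_arg}.pt"
    return checkpoint_arg  # otherwise, treat it as a filename

print(resolve("default", "eval"))     # checkpoint_best.pt
print(resolve("7", "train"))          # checkpoint_7.pt
print(resolve("my/ckpt.pt", "eval"))  # my/ckpt.pt
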