def __init__(self, filename, config_path, config, message_store):
    self.logger = get_logger()
    self.task_index = {t: i for i, t in enumerate(self.task_order)}
    self.message_store = message_store
    self.filename = filename
    self.filename_path = config_path
    self.run_config = config
    self.global_config = get_config()
    self.prefix = self.global_config["QUEUE"]["prefix"] + "_" + filename
    self.max_jobs = int(self.global_config["QUEUE"]["max_jobs"])
    self.max_jobs_gpu = int(self.global_config["QUEUE"]["max_gpu_jobs"])
    self.max_jobs_in_queue = int(self.global_config["QUEUE"]["max_jobs_in_queue"])
    self.max_jobs_in_queue_gpu = int(self.global_config["QUEUE"]["max_gpu_jobs_in_queue"])
    self.output_dir = os.path.join(get_output_dir(), self.filename)
    self.tasks = None
    self.num_jobs_queue = 0
    self.num_jobs_queue_gpu = 0
    self.start = None
    self.finish = None
    self.force_refresh = False
def __init__(self, filename, config):
    self.logger = get_logger()
    self.filename = filename
    self.run_config = config
    self.global_config = get_config()
    self.prefix = "PIP_" + filename
    self.output_dir = None
def __init__(self, filename, config_path, config_raw, config, message_store):
    self.logger = get_logger()
    self.task_index = {t: i for i, t in enumerate(self.task_order)}
    self.message_store = message_store
    self.filename = filename
    self.filename_path = config_path
    self.file_raw = config_raw
    self.run_config = config
    self.global_config = get_config()
    self.prefix = self.global_config["QUEUE"]["prefix"] + "_" + filename
    self.max_jobs = int(self.global_config["QUEUE"]["max_jobs"])
    self.max_jobs_gpu = int(self.global_config["QUEUE"]["max_gpu_jobs"])
    self.max_jobs_in_queue = int(self.global_config["QUEUE"]["max_jobs_in_queue"])
    self.max_jobs_in_queue_gpu = int(self.global_config["QUEUE"]["max_gpu_jobs_in_queue"])
    self.logger.debug(self.global_config.keys())

    self.sbatch_cpu_path = get_data_loc(self.global_config["SBATCH"]["cpu_location"])
    with open(self.sbatch_cpu_path, 'r') as f:
        self.sbatch_cpu_header = f.read()
    self.sbatch_gpu_path = get_data_loc(self.global_config["SBATCH"]["gpu_location"])
    with open(self.sbatch_gpu_path, 'r') as f:
        self.sbatch_gpu_header = f.read()
    self.sbatch_cpu_header = self.clean_header(self.sbatch_cpu_header)
    self.sbatch_gpu_header = self.clean_header(self.sbatch_gpu_header)
    self.setup_task_location = self.global_config["SETUP"]["location"]
    self.load_task_setup()

    self.output_dir = os.path.join(get_output_dir(), self.filename)
    self.tasks = None
    self.num_jobs_queue = 0
    self.num_jobs_queue_gpu = 0
    self.start = None
    self.finish = None
    self.force_refresh = False
    self.force_ignore_stage = None
    self.running = []
    self.done = []
    self.failed = []
    self.blocked = []
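# For illustration only: a minimal sketch of the global-config sections this
# constructor reads. Every key below is taken from the lookups above; the
# values are hypothetical placeholders, not defaults shipped with the code.
_example_global_config = {
    "QUEUE": {
        "prefix": "PIP",              # job-name prefix, combined with the input filename
        "max_jobs": 100,              # cap on concurrently running CPU jobs
        "max_gpu_jobs": 10,           # cap on concurrently running GPU jobs
        "max_jobs_in_queue": 200,     # cap on queued CPU jobs
        "max_gpu_jobs_in_queue": 20,  # cap on queued GPU jobs
    },
    "SBATCH": {
        "cpu_location": "sbatch_cpu.sh",  # resolved via get_data_loc()
        "gpu_location": "sbatch_gpu.sh",
    },
    "SETUP": {
        "location": "task_setup",  # consumed by load_task_setup()
    },
}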
def __init__(self, filename, config, message_store):
    self.logger = get_logger()
    self.message_store = message_store
    self.filename = filename
    self.run_config = config
    self.global_config = get_config()
    self.prefix = self.global_config["GLOBAL"]["prefix"] + "_" + filename
    self.max_jobs = int(self.global_config["GLOBAL"]["max_jobs"])
    self.max_jobs_in_queue = int(self.global_config["GLOBAL"]["max_jobs_in_queue"])
    self.output_dir = os.path.abspath(
        os.path.dirname(inspect.stack()[0][1]) + "/../"
        + self.global_config['OUTPUT']['output_dir'] + "/" + self.filename)
    self.tasks = None
    self.start = None
    self.finish = None
    self.force_refresh = False
def setup_logging(config_filename, logging_folder, args):
    level = logging.DEBUG if args.verbose else logging.INFO
    logging_filename = f"{logging_folder}/{config_filename}.log"

    message_store = MessageStore()
    message_store.setLevel(logging.WARNING)

    NOTICE_LEVELV_NUM = 25
    logging.addLevelName(NOTICE_LEVELV_NUM, "NOTICE")

    def notice(self, message, *args, **kws):
        if self.isEnabledFor(NOTICE_LEVELV_NUM):
            self._log(NOTICE_LEVELV_NUM, message, args, **kws)

    logging.Logger.notice = notice

    fmt_verbose = "[%(levelname)8s |%(filename)21s:%(lineno)3d] %(message)s"
    fmt_info = "%(message)s"
    fmt = fmt_verbose if args.verbose else fmt_info

    logger = get_logger()
    handlers = [message_store]
    if not args.check:
        handlers.append(logging.FileHandler(logging_filename, mode="w"))
        handlers[-1].setLevel(logging.DEBUG)
        handlers[-1].setFormatter(logging.Formatter(fmt_verbose))
    # logging.basicConfig(level=level, format=fmt, handlers=handlers)
    for h in handlers:
        logger.addHandler(h)

    coloredlogs.install(
        level=level,
        fmt=fmt,
        reconfigure=True,
        level_styles=coloredlogs.parse_encoded_styles(
            "debug=8;notice=green;warning=yellow;error=red,bold;critical=red,inverse"
        ),
    )
    logging.getLogger("matplotlib").setLevel(logging.ERROR)
    logger.info(f"Logging streaming out, also saving to {logging_filename}")
    return message_store, logging_filename
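# A minimal usage sketch for setup_logging, assuming an argparse namespace with
# the attributes the function actually reads (verbose, check). The program
# name, log folder, and flag help strings are illustrative, not from the source.
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("-v", "--verbose", action="store_true", help="DEBUG level, verbose format")
    parser.add_argument("-c", "--check", action="store_true", help="do not write a log file to disk")
    args = parser.parse_args()

    os.makedirs("logs", exist_ok=True)
    message_store, log_file = setup_logging("example_run", "logs", args)
    get_logger().notice(f"Logging ready, writing to {log_file}")  # NOTICE level added by setup_logging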
def __init__(self, name, output_dir, dependencies=None):
    self.logger = get_logger()
    self.name = name
    self.output_dir = output_dir
    self.num_jobs = 1
    if dependencies is None:
        dependencies = []
    self.dependencies = dependencies
    self.hash = None
    self.output = {"name": name, "output_dir": output_dir}
    self.hash_file = os.path.join(self.output_dir, "hash.txt")
    self.done_file = os.path.join(self.output_dir, "done.txt")
    self.start_time = None
    self.end_time = None
    self.wall_time = None
    self.stage = None
import copy
import datetime
import os
import shutil
import tarfile
from abc import ABC, abstractmethod

import numpy as np
import yaml

# Project-local helpers (get_logger, get_data_loc, get_hash, read_yaml,
# compress_dir, uncompress_dir, ensure_list) are assumed to be imported from
# the package's utility modules.


class Task(ABC):
    FINISHED_SUCCESS = -1
    FINISHED_FAILURE = -9
    logger = get_logger()

    def __init__(self, name, output_dir, dependencies=None, config=None, done_file="done.txt"):
        self.name = name
        self.output_dir = output_dir
        self.num_jobs = 1
        if dependencies is None:
            dependencies = []
        self.dependencies = dependencies
        if config is None:
            config = {}
        self.config = copy.deepcopy(config)
        self.output = {}

        # Determine if this is an external (already done) job or not
        external_dirs = self.config.get("EXTERNAL_DIRS", [])
        external_names = [os.path.basename(d) for d in external_dirs]
        external_map = self.config.get("EXTERNAL_MAP", {})
        output_name = os.path.basename(output_dir)
        name_match = external_map.get(output_name)
        if external_dirs:
            # This will only trigger if EXTERNAL_MAP is defined and output_name is in external_map
            if name_match is not None:
                matching_dirs = [d for d in external_dirs if name_match in d]
                if len(matching_dirs) == 0:
                    self.logger.error(
                        f"Task {output_name} has external mapping {name_match} but there were no matching EXTERNAL_DIRS")
                else:
                    if len(matching_dirs) > 1:
                        self.logger.warning(
                            f"Task {output_name} has external mapping {name_match} which matched with multiple EXTERNAL_DIRS: {matching_dirs}. Defaulting to {matching_dirs[0]}")
                    self.logger.info(f"Found external match for {output_name}")
                    self.config["EXTERNAL"] = matching_dirs[0]
            # If you haven't specified an EXTERNAL_MAP for this output_name, check for an exact match
            elif output_name in external_names:
                self.config["EXTERNAL"] = external_dirs[external_names.index(output_name)]
            else:
                self.logger.info(f"No external match found for {output_name}")

        self.external = self.config.get("EXTERNAL")
        if self.external is not None:
            self.logger.debug(f"External config stated to be {self.external}")
            self.external = get_data_loc(self.external)
            # External directory might be compressed
            if not os.path.exists(self.external):
                self.logger.warning(
                    f"External config {self.external} does not exist, checking if it's compressed")
                compressed_dir = self.external + ".tar.gz"
                if not os.path.exists(compressed_dir):
                    self.logger.error(f"{self.external} and {compressed_dir} do not exist")
                else:
                    self.external = compressed_dir
                    self.logger.debug(f"External config file path resolved to {self.external}")
                    with tarfile.open(self.external, "r:gz") as tar:
                        for member in tar:
                            if member.isfile():
                                filename = os.path.basename(member.name)
                                if filename != "config.yml":
                                    continue
                                with tar.extractfile(member) as f:
                                    external_config = yaml.load(f, Loader=yaml.Loader)
                                    conf = external_config.get("CONFIG", {})
                                    conf.update(self.config)
                                    self.config = conf
                                    self.output = external_config.get("OUTPUT", {})
                                    self.logger.debug("Loaded external config successfully")
            else:
                if os.path.isdir(self.external):
                    self.external = os.path.join(self.external, "config.yml")
                self.logger.debug(f"External config file path resolved to {self.external}")
                with open(self.external, "r") as f:
                    external_config = yaml.load(f, Loader=yaml.Loader)
                    conf = external_config.get("CONFIG", {})
                    conf.update(self.config)
                    self.config = conf
                    self.output = external_config.get("OUTPUT", {})
                    self.logger.debug("Loaded external config successfully")

        self.hash = None
        self.hash_file = os.path.join(self.output_dir, "hash.txt")
        self.done_file = os.path.join(self.output_dir, done_file)

        # Info about the job run
        self.start_time = None
        self.end_time = None
        self.wall_time = None
        self.stage = None
        self.fresh_run = True
        self.num_empty = 0
        self.num_empty_threshold = 10
        self.display_threshold = 0
        self.gpu = False
        self.force_refresh = False
        self.force_ignore = False

        self.output.update({
            "name": name,
            "output_dir": output_dir,
            "hash_file": self.hash_file,
            "done_file": self.done_file,
        })

        self.config_file = os.path.join(output_dir, "config.yml")

    def set_force_refresh(self, force_refresh):
        self.force_refresh = force_refresh

    def set_force_ignore(self, force_ignore):
        self.force_ignore = force_ignore

    def set_setup(self, setup):
        self.task_setup = setup

    def set_sbatch_cpu_header(self, header):
        self.logger.debug("Set cpu header")
        self.sbatch_cpu_header = header

    def set_sbatch_gpu_header(self, header):
        self.logger.debug("Set gpu header")
        self.sbatch_gpu_header = header

    def update_setup(self, setup_dict, task_setup):
        return task_setup.format(**setup_dict)

    def update_header(self, header_dict):
        for key, value in header_dict.items():
            if key in self.sbatch_header:
                self.sbatch_header = self.sbatch_header.replace(key, str(value))
        append_list = header_dict.get("APPEND")
        if append_list is not None:
            lines = self.sbatch_header.split('\n')
            lines += append_list
            self.sbatch_header = '\n'.join(lines)
        self.logger.debug("Updated header")

    def clean_header(self, header):
        lines = header.split('\n')
        mask = lambda x: (len(x) > 0) and (x[0] == '#') and ('xxxx' not in x)
        lines = filter(mask, lines)
        header = '\n'.join(lines)
        return header

    def compress(self):
        if os.path.exists(self.output_dir):
            output_file = self.output_dir + ".tar.gz"
            compress_dir(output_file, self.output_dir)
        for t in self.dependencies:
            if os.path.exists(t.output_dir):
                output_file = t.output_dir + ".tar.gz"
                compress_dir(output_file, t.output_dir)

    def uncompress(self):
        source_file = self.output_dir + ".tar.gz"
        if os.path.exists(source_file):
            uncompress_dir(os.path.dirname(self.output_dir), source_file)
        for t in self.dependencies:
            source_file = t.output_dir + ".tar.gz"
            if os.path.exists(source_file):
                uncompress_dir(os.path.dirname(t.output_dir), source_file)

    def _check_regenerate(self, new_hash):
        hash_are_different = new_hash != self.get_old_hash()

        if self.force_ignore:
            if hash_are_different:
                self.logger.warning(
                    f"Warning, hashes are different for {self}, but force_ignore is True so regenerate=False")
            else:
                self.logger.debug("Hashes agree and force_ignore is set, returning regenerate=False")
            return False
        elif self.force_refresh:
            self.logger.debug("Force refresh is set, returning regenerate=True")
            return True
        else:
            if hash_are_different:
                self.logger.debug("Hashes are different, regenerating")
                return True
            else:
                self.logger.debug("Hashes are the same, not regenerating")
                return False

    def write_config(self):
        content = {"CONFIG": self.config, "OUTPUT": self.output}
        with open(self.config_file, "w") as f:
            yaml.dump(content, f, sort_keys=False)

    def load_config(self):
        with open(self.config_file, "r") as f:
            content = yaml.safe_load(f)
        return content

    def clear_config(self):
        if os.path.exists(self.config_file):
            os.remove(self.config_file)

    def clear_hash(self):
        if os.path.exists(self.hash_file):
            os.remove(self.hash_file)
        self.clear_config()

    def check_for_job(self, squeue, match):
        if squeue is None:
            return self.num_jobs
        num_jobs = len([i for i in squeue if match in i])
        if num_jobs == 0:
            self.num_empty += 1
            if self.num_empty >= self.num_empty_threshold:
                self.logger.error(
                    f"No more waiting, there are no slurm jobs active that match {match}! Debug output dir {self.output_dir}")
                return Task.FINISHED_FAILURE
            elif self.num_empty > 1 and self.num_empty > self.display_threshold:
                self.logger.warning(
                    f"Task {str(self)} has no match for {match} in squeue, warning {self.num_empty}/{self.num_empty_threshold}")
            return 0
        return num_jobs

    def should_be_done(self):
        self.fresh_run = False

    def set_stage(self, stage):
        self.stage = stage

    def get_old_hash(self, quiet=False, required=False):
        if os.path.exists(self.hash_file):
            with open(self.hash_file, "r") as f:
                old_hash = f.read().strip()
            if not quiet:
                self.logger.debug(f"Previous result found, hash is {old_hash}")
            return old_hash
        else:
            if required:
                self.logger.error(f"No hash found for {self} in {self.hash_file}")
            else:
                self.logger.debug(f"No hash found for {self}")
        return "_NONE_"

    def get_hash_from_files(self, output_files):
        string_to_hash = ""
        for file in output_files:
            with open(file, "r") as f:
                string_to_hash += f.read()
        new_hash = self.get_hash_from_string(string_to_hash)
        return new_hash

    def get_hash_from_string(self, string_to_hash):
        hashes = sorted([dep.get_old_hash(quiet=True, required=True) for dep in self.dependencies])
        string_to_hash += " ".join(hashes)
        new_hash = get_hash(string_to_hash)
        self.logger.debug(f"Current hash set to {new_hash}")
        return new_hash

    def save_new_hash(self, new_hash):
        with open(self.hash_file, "w") as f:
            f.write(str(new_hash))
        self.logger.debug(f"New hash {new_hash}")
        self.logger.debug(f"New hash saved to {self.hash_file}")

    def set_num_jobs(self, num_jobs):
        self.num_jobs = num_jobs

    def add_dependency(self, task):
        self.dependencies.append(task)

    def run(self):
        self.uncompress()
        if self.external is not None:
            self.logger.debug(f"Name: {self.name} External: {self.external}")
            if os.path.exists(self.output_dir) and not self.force_refresh:
                self.logger.info(
                    f"Not copying external site, output_dir already exists at {self.output_dir}")
            else:
                if os.path.exists(self.output_dir):
                    self.logger.debug(f"Removing old directory {self.output_dir}")
                    shutil.rmtree(self.output_dir, ignore_errors=True)
                if ".tar.gz" in self.external:
                    tardir = os.path.basename(self.external).replace(".tar.gz", "")
                    self.logger.info(f"Copying files from {self.external} to {self.output_dir}")
                    shutil.copyfile(self.external, self.output_dir + '.tar.gz')
                    self.uncompress()
                    shutil.move(os.path.join(os.path.dirname(self.output_dir), tardir), self.output_dir)
                else:
                    self.logger.info(
                        f"Copying from {os.path.dirname(self.external)} to {self.output_dir}")
                    shutil.copytree(os.path.dirname(self.external), self.output_dir, symlinks=True)
            return True
        return self._run()

    def scan_file_for_error(self, path, *error_match, max_lines=10):
        assert len(error_match) >= 1, "You need to specify what string to search for. I have nothing."
        found = False
        if not os.path.exists(path):
            self.logger.warning(f"Note, expected log path {path} does not exist")
            return False
        with open(path) as f:
            for i, line in enumerate(f.read().splitlines()):
                error_found = np.any([e in line for e in error_match])
                if error_found:
                    index = i
                    found = True
                    self.logger.error(f"Found error in file {path}, excerpt below")
                if found and i - index <= max_lines:
                    self.logger.error(f"Excerpt: {line}")
        return found

    def scan_files_for_error(self, paths, *error_match, max_lines=10, max_erroring_files=3):
        num_errors = 0
        self.logger.debug(f"Found {len(paths)} to scan")
        for path in paths:
            if "FAIL_SUMMARY.LOG" in path.upper():
                self.logger.debug(f"Found {path}, loading in YAML contents")
                fail_summary = read_yaml(path)
                for key, dicts in fail_summary.items():
                    if key.startswith("FAILURE-0"):
                        self.logger.error(
                            f"{key}: {' '.join(dicts.get('ABORT_MESSAGES', 'Unknown message'))}")
                        self.logger.error(
                            f"{key}: Detailed in {dicts.get('JOB_LOG_FILE', 'Unknown path')}")
                        num_errors += 1
                    if num_errors > max_erroring_files:
                        break
            else:
                self.logger.debug(f"Scanning {path} for error")
                if self.scan_file_for_error(path, *error_match, max_lines=max_lines):
                    num_errors += 1
            if num_errors >= max_erroring_files:
                break
        return num_errors > 0

    @staticmethod
    def match_tasks(mask, deps, match_none=True, allowed_failure=False):
        if mask is None:
            if match_none:
                mask = ""
            else:
                return []
        if isinstance(mask, str):
            if mask == "*":
                mask = ""
        mask = ensure_list(mask)
        matching_deps = [d for d in deps if any(x in d.name for x in mask)]
        for m in mask:
            specific_match = [d for d in matching_deps if m in d.name]
            if len(specific_match) == 0 and not allowed_failure:
                Task.fail_config(
                    f"Mask '{m}' does not match any deps. Probably a typo. Available options are {deps}")
        return matching_deps

    @staticmethod
    def match_tasks_of_type(mask, deps, *cls, match_none=True, allowed_failure=False):
        return Task.match_tasks(mask, Task.get_task_of_type(deps, *cls),
                                match_none=match_none, allowed_failure=allowed_failure)

    @abstractmethod
    def _run(self):
        """
        Execute the primary function of the task.

        Refresh behaviour is controlled via `self.force_refresh` and
        `self.force_ignore` rather than an argument.

        :return: true or false if the job launched successfully
        """
        pass

    @staticmethod
    def get_task_of_type(tasks, *cls):
        return [t for t in tasks if isinstance(t, tuple(cls))]

    @staticmethod
    def fail_config(message):
        Task.logger.error(message)
        raise ValueError("Task failed config")

    @staticmethod
    @abstractmethod
    def get_tasks(config, prior_tasks, base_output_dir, stage_number, prefix, global_config):
        raise NotImplementedError()

    def get_wall_time_str(self):
        if self.end_time is not None and self.start_time is not None:
            return str(datetime.timedelta(seconds=self.wall_time))
        return None

    def check_completion(self, squeue):
        """ Checks if the job has completed.

        Invokes `_check_completion` and determines wall time.
        :return: Task.FINISHED_SUCCESS, Task.FINISHED_FAILURE or the number of jobs still running
        """
        result = self._check_completion(squeue)
        if result in [Task.FINISHED_SUCCESS, Task.FINISHED_FAILURE]:
            if os.path.exists(self.done_file):
                self.end_time = os.path.getmtime(self.done_file)
                if self.start_time is None and os.path.exists(self.hash_file):
                    self.start_time = os.path.getmtime(self.hash_file)
                if self.end_time is not None and self.start_time is not None:
                    self.wall_time = int(self.end_time - self.start_time + 0.5)  # round up
                    self.logger.info(f"Task finished with wall time {self.get_wall_time_str()}")
            if result == Task.FINISHED_FAILURE:
                self.clear_hash()
        elif not self.fresh_run:
            self.logger.error("Hash check had passed, so the task should be done, but it said it wasn't!")
            self.logger.error(f"This means it probably crashed, have a look in {self.output_dir}")
            self.logger.error(f"Removing hash from {self.hash_file}")
            self.clear_hash()
            return Task.FINISHED_FAILURE
        if self.external is None and result == Task.FINISHED_SUCCESS and not os.path.exists(self.config_file):
            self.write_config()
        return result

    @abstractmethod
    def _check_completion(self, squeue):
        """ Checks if the job is complete or has failed.

        If it is complete, it should also load any useful results that other
        tasks may need into the `self.output` dictionary, such as the location
        of a trained model or output files.

        :param squeue:
        """
        pass

    def __str__(self):
        wall_time = self.get_wall_time_str()
        if wall_time is not None:
            extra = f"wall time {wall_time}, "
        else:
            extra = ""
        if len(self.dependencies) > 5:
            deps = f"{[d.name for d in self.dependencies[:5]]} + {len(self.dependencies) - 5} more deps"
        else:
            deps = f"{[d.name for d in self.dependencies]}"
        if self.external is None:
            return f"{self.__class__.__name__} {self.name} task ({extra}{self.num_jobs} jobs, deps {deps})"
        else:
            return f"{self.__class__.__name__} {self.name} task (EXTERNAL JOB, deps {deps})"

    def __repr__(self):
        return self.__str__()

    def get_dep(self, *clss, fail=False):
        for d in self.dependencies:
            for cls in clss:
                if isinstance(d, cls):
                    return d
        if fail:
            raise ValueError(f"No deps have class of type {clss}")
        return None

    def get_deps(self, *clss):
        return [d for d in self.dependencies if isinstance(d, tuple(clss))]
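# A minimal concrete subclass, sketched to illustrate the Task contract rather
# than any real pipeline stage: _run() regenerates outputs only when the hash
# changes, and _check_completion() reports success, failure, or jobs remaining.
# ExampleTask and its trivial "work" are hypothetical.
class ExampleTask(Task):
    def _run(self):
        new_hash = self.get_hash_from_string(str(self.config))
        if self._check_regenerate(new_hash):
            os.makedirs(self.output_dir, exist_ok=True)
            self.save_new_hash(new_hash)
            with open(self.done_file, "w") as f:
                f.write("SUCCESS")  # a real task would submit slurm jobs here
        else:
            self.should_be_done()
        return True

    def _check_completion(self, squeue):
        if os.path.exists(self.done_file):
            return Task.FINISHED_SUCCESS
        return self.check_for_job(squeue, self.name)

    @staticmethod
    def get_tasks(config, prior_tasks, base_output_dir, stage_number, prefix, global_config):
        return [ExampleTask("example", os.path.join(base_output_dir, "example"))]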
logging.basicConfig(
    level=level,
    format=fmt,
    handlers=[
        logging.FileHandler(logging_filename),
        logging.StreamHandler(),
        message_store,
    ])
coloredlogs.install(
    level=level,
    fmt=fmt,
    reconfigure=True,
    level_styles=coloredlogs.parse_encoded_styles(
        'debug=8;notice=green;warning=yellow;error=red,bold;critical=red,inverse'
    ))
logging.getLogger('matplotlib').setLevel(logging.ERROR)
logger = get_logger()
logger.info(f"Logging streaming out, also saving to {logging_filename}")

# Load YAML config file
config_path = os.path.dirname(inspect.stack()[0][1]) + args.config
assert os.path.exists(config_path), f"File {config_path} cannot be found."
with open(config_path, "r") as f:
    config = yaml.safe_load(f)

manager = Manager(config_filename, config, message_store)
if args.start is not None:
    args.refresh = True
    manager.set_start(args.start)
manager.set_finish(args.finish)
manager.set_force_refresh(args.refresh)
manager.execute()
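# For context: a hedged sketch of the argparse interface this fragment and
# setup_logging assume. Only the attribute names (config, start, finish,
# refresh, check, verbose) are grounded in the code; the flag spellings and
# help strings are guesses, not confirmed CLI options.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("config", help="YAML config file, relative to this script")
parser.add_argument("--start", default=None, help="stage to start from (implies a refresh)")
parser.add_argument("--finish", default=None, help="stage to stop after")
parser.add_argument("--refresh", action="store_true", help="force refresh, skipping hash checks")
parser.add_argument("--check", action="store_true", help="check the config without executing")
parser.add_argument("--verbose", action="store_true", help="debug-level logging")
args = parser.parse_args()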
class Task(ABC):
    FINISHED_SUCCESS = -1
    FINISHED_FAILURE = -9
    logger = get_logger()

    def __init__(self, name, output_dir, dependencies=None):
        self.name = name
        self.output_dir = output_dir
        self.num_jobs = 1
        if dependencies is None:
            dependencies = []
        self.dependencies = dependencies
        self.hash = None
        self.output = {"name": name, "output_dir": output_dir}
        self.hash_file = os.path.join(self.output_dir, "hash.txt")
        self.done_file = os.path.join(self.output_dir, "done.txt")
        self.start_time = None
        self.end_time = None
        self.wall_time = None
        self.stage = None
        self.fresh_run = True
        self.num_empty = 0
        self.num_empty_threshold = 10
        self.display_threshold = 0
        self.gpu = False

    def check_for_job(self, squeue, match):
        if squeue is None:
            return self.num_jobs
        num_jobs = len([i for i in squeue if match in i])
        if num_jobs == 0:
            self.num_empty += 1
            if self.num_empty >= self.num_empty_threshold:
                self.logger.error(
                    f"No more waiting, there are no slurm jobs active that match {match}! Debug output dir {self.output_dir}")
                return Task.FINISHED_FAILURE
            elif self.num_empty > 1 and self.num_empty > self.display_threshold:
                self.logger.warning(
                    f"Task {str(self)} has no match for {match} in squeue, warning {self.num_empty}/{self.num_empty_threshold}")
            return 0
        return num_jobs

    def should_be_done(self):
        self.fresh_run = False

    def set_stage(self, stage):
        self.stage = stage

    def get_old_hash(self, quiet=False, required=False):
        if os.path.exists(self.hash_file):
            with open(self.hash_file, "r") as f:
                old_hash = f.read().strip()
            if not quiet:
                self.logger.debug(f"Previous result found, hash is {old_hash}")
            return old_hash
        else:
            if required:
                self.logger.error(f"No hash found for {self}")
            else:
                self.logger.debug(f"No hash found for {self}")
        return None

    def get_hash_from_files(self, output_files):
        string_to_hash = ""
        for file in output_files:
            with open(file, "r") as f:
                string_to_hash += f.read()
        new_hash = self.get_hash_from_string(string_to_hash)
        return new_hash

    def get_hash_from_string(self, string_to_hash):
        hashes = sorted([dep.get_old_hash(quiet=True, required=True) for dep in self.dependencies])
        string_to_hash += " ".join(hashes)
        new_hash = get_hash(string_to_hash)
        self.logger.debug(f"Current hash set to {new_hash}")
        return new_hash

    def save_new_hash(self, new_hash):
        with open(self.hash_file, "w") as f:
            f.write(str(new_hash))
        self.logger.debug(f"New hash {new_hash}")
        self.logger.debug(f"New hash saved to {self.hash_file}")

    def set_num_jobs(self, num_jobs):
        self.num_jobs = num_jobs

    def add_dependency(self, task):
        self.dependencies.append(task)

    def run(self, force_refresh):
        return self._run(force_refresh)

    def scan_file_for_error(self, path, *error_match, max_lines=10):
        assert len(error_match) >= 1, "You need to specify what string to search for. I have nothing."
        found = False
        if not os.path.exists(path):
            self.logger.warning(f"Note, expected log path {path} does not exist")
            return False
        with open(path) as f:
            for i, line in enumerate(f.read().splitlines()):
                error_found = np.any([e in line for e in error_match])
                if error_found:
                    index = i
                    found = True
                    self.logger.error(f"Found error in file {path}, excerpt below")
                if found and i - index <= max_lines:
                    self.logger.error(f"Excerpt: {line}")
        return found

    def scan_files_for_error(self, paths, *error_match, max_lines=10, max_erroring_files=3):
        num_errors = 0
        self.logger.debug(f"Found {len(paths)} to scan")
        for path in paths:
            self.logger.debug(f"Scanning {path} for error")
            if self.scan_file_for_error(path, *error_match, max_lines=max_lines):
                num_errors += 1
            if num_errors >= max_erroring_files:
                break
        return num_errors > 0

    @staticmethod
    def match_tasks(mask, deps, match_none=True):
        if mask is None:
            if match_none:
                mask = ""
            else:
                return []
        if isinstance(mask, str):
            if mask == "*":
                mask = ""
        mask = ensure_list(mask)
        matching_deps = [d for d in deps if any(x in d.name for x in mask)]
        for m in mask:
            specific_match = [d for d in matching_deps if m in d.name]
            if len(specific_match) == 0:
                Task.fail_config(
                    f"Mask '{m}' does not match any deps. Probably a typo. Available options are {deps}")
        return matching_deps

    @staticmethod
    def match_tasks_of_type(mask, deps, *cls, match_none=True):
        return Task.match_tasks(mask, Task.get_task_of_type(deps, *cls), match_none=match_none)

    @abstractmethod
    def _run(self, force_refresh):
        """
        Execute the primary function of the task.

        :param force_refresh: force a refresh and rerun, skipping hash checks
        :return: true or false if the job launched successfully
        """
        pass

    @staticmethod
    def get_task_of_type(tasks, *cls):
        return [t for t in tasks if isinstance(t, tuple(cls))]

    @staticmethod
    def fail_config(message):
        Task.logger.error(message)
        raise ValueError("Task failed config")

    @staticmethod
    @abstractmethod
    def get_tasks(config, prior_tasks, base_output_dir, stage_number, prefix, global_config):
        raise NotImplementedError()

    def get_wall_time_str(self):
        if self.end_time is not None and self.start_time is not None:
            return str(datetime.timedelta(seconds=self.wall_time))
        return None

    def check_completion(self, squeue):
        """ Checks if the job has completed.

        Invokes `_check_completion` and determines wall time.

        :return: Task.FINISHED_SUCCESS, Task.FINISHED_FAILURE or the number of jobs still running
        """
        result = self._check_completion(squeue)
        if result in [Task.FINISHED_SUCCESS, Task.FINISHED_FAILURE]:
            if os.path.exists(self.done_file):
                self.end_time = os.path.getmtime(self.done_file)
                if self.start_time is None and os.path.exists(self.hash_file):
                    self.start_time = os.path.getmtime(self.hash_file)
                if self.end_time is not None and self.start_time is not None:
                    self.wall_time = int(self.end_time - self.start_time + 0.5)  # round up
                    self.logger.info(f"Task finished with wall time {self.get_wall_time_str()}")
            if result == Task.FINISHED_FAILURE and os.path.exists(self.hash_file):
                os.remove(self.hash_file)
        elif not self.fresh_run:
            self.logger.error("Hash check had passed, so the task should be done, but it said it wasn't!")
            self.logger.error(f"This means it probably crashed, have a look in {self.output_dir}")
            self.logger.error(f"Removing hash from {self.hash_file}")
            if os.path.exists(self.hash_file):
                os.remove(self.hash_file)
            return Task.FINISHED_FAILURE
        return result

    @abstractmethod
    def _check_completion(self, squeue):
        """ Checks if the job is complete or has failed.
        If it is complete, it should also load any useful results that other
        tasks may need into the `self.output` dictionary, such as the location
        of a trained model or output files.

        :param squeue:
        """
        pass

    def __str__(self):
        wall_time = self.get_wall_time_str()
        if wall_time is not None:
            extra = f"wall time {wall_time}, "
        else:
            extra = ""
        if len(self.dependencies) > 5:
            deps = f"{[d.name for d in self.dependencies[:5]]} + {len(self.dependencies) - 5} more deps"
        else:
            deps = f"{[d.name for d in self.dependencies]}"
        return f"{self.__class__.__name__} {self.name} task ({extra}{self.num_jobs} jobs, deps {deps})"

    def __repr__(self):
        return self.__str__()

    def get_dep(self, *clss, fail=False):
        for d in self.dependencies:
            for cls in clss:
                if isinstance(d, cls):
                    return d
        if fail:
            raise ValueError(f"No deps have class of type {clss}")
        return None

    def get_deps(self, *clss):
        return [d for d in self.dependencies if isinstance(d, tuple(clss))]
def __init__(self, output_dir):
    self.logger = get_logger()
    self.output_dir = output_dir
    mkdirs(self.output_dir)