Example #1
0
        def wrapper(cfg_passthrough: Optional[DictConfig] = None) -> Any:
            """Run ``task_function`` either directly or through Hydra.

            Args:
                cfg_passthrough: A ready-made config. When it is given, Hydra is
                    bypassed and ``task_function`` is called with it directly.

            Returns:
                ``task_function(cfg_passthrough)`` on the direct path. When Hydra
                drives execution, nothing is returned (``None``), because Hydra
                may call the task several times (``--multirun``).
            """
            # Direct path: the caller already has a config, so skip CLI parsing.
            if cfg_passthrough is not None:
                return task_function(cfg_passthrough)

            # Hydra path: no config file directory, and the config name is the
            # enclosing ``class_name``. _run_hydra() produces no usable return
            # value since it may run the task function multiple times.
            _run_hydra(
                args_parser=get_args_parser(),
                task_function=task_function,
                config_path=None,
                config_name=class_name,
                strict=None,
            )
Example #2
0
        def wrapper(cfg_passthrough: Optional[DictConfig] = None) -> Any:
            """Run ``task_function`` either directly or through Hydra.

            Args:
                cfg_passthrough: A ready-made config. When it is given, Hydra is
                    bypassed and ``task_function`` is called with it directly.

            Returns:
                ``task_function(cfg_passthrough)`` on the direct path; ``None``
                when Hydra drives execution (it may call the task several times
                under ``--multirun``).

            On the Hydra path the command line is parsed once here, two
            overrides are appended (``hydra.run.dir=.`` and
            ``hydra.job_logging.root.handlers=null``), and Hydra is handed a
            stand-in parser that returns this pre-parsed namespace.
            """
            if cfg_passthrough is not None:
                return task_function(cfg_passthrough)

            parser = get_args_parser()

            # Parse up front so the override list can be amended before Hydra
            # sees it; ``overrides`` is the list stored on the namespace itself.
            parsed_args = parser.parse_args()
            overrides = parsed_args.overrides  # type: list
            overrides.extend(["hydra.run.dir=.", 'hydra.job_logging.root.handlers=null'])

            class _argparse_wrapper:
                """Parser stand-in whose ``parse_args`` returns ``parsed_args``.

                Mimics the ``ArgumentParser.parse_args()`` API so Hydra receives
                the already-amended namespace instead of re-parsing argv.
                """

                def __init__(self, arg_parser):
                    self.arg_parser = arg_parser
                    # Mirror the real parser's actions; presumably read by Hydra
                    # (e.g. for help output) — verify against _run_hydra.
                    self._actions = arg_parser._actions

                def parse_args(self, args=None, namespace=None):
                    # Arguments are ignored: the namespace was parsed above.
                    return parsed_args

            # _run_hydra() yields no usable result: it may run the task
            # function multiple times (--multirun).
            _run_hydra(
                args_parser=_argparse_wrapper(parser),
                task_function=task_function,
                config_path=config_path,
                config_name=config_name,
                strict=None,
            )
Example #3
0
        def wrapper(cfg_passthrough: Optional[DictConfig] = None) -> Any:
            """Run ``task_function`` directly or through Hydra, optionally with a schema.

            Args:
                cfg_passthrough: A ready-made config. When it is given, Hydra is
                    bypassed and ``task_function`` is called with it directly.

            Returns:
                ``task_function(cfg_passthrough)`` on the direct path; ``None``
                when Hydra drives execution (it may call the task several times
                under ``--multirun``).

            On the Hydra path, three overrides are appended to the parsed
            command line. If ``schema`` is set, it is registered in the
            ``ConfigStore`` under the config name. In that case the process
            writes an error to stderr and exits with status 1 when
            ``--config-name`` contains a directory component.
            """
            if cfg_passthrough is not None:
                return task_function(cfg_passthrough)

            parser = get_args_parser()

            # Parse now so the overrides can be amended before Hydra runs.
            parsed_args = parser.parse_args()  # type: argparse.Namespace
            overrides = parsed_args.overrides  # type: list
            overrides.extend([
                # Disable the creation of the .hydra subdir:
                # https://hydra.cc/docs/tutorials/basic/running_your_app/working_directory
                "hydra.output_subdir=null",
                # Hydra logging goes to stdout only (no log file):
                # https://hydra.cc/docs/configure_hydra/logging
                "hydra/job_logging=stdout",
                # run.dir is set ONLY for ExpManager "compatibility" - to be removed.
                "hydra.run.dir=.",
            ])

            if schema is not None:
                store = ConfigStore.instance()

                # Resolve the ConfigStore name under which the schema is injected.
                if parsed_args.config_name is None:
                    name = config_name
                else:
                    path, name = os.path.split(parsed_args.config_name)
                    # A path inside --config-name would disable schema
                    # validation, so refuse it.
                    if path != '':
                        sys.stderr.write(
                            f"ERROR Cannot set config file path using `--config-name` when "
                            "using schema. Please set path using `--config-path` and file name using "
                            "`--config-name` separately.\n")
                        sys.exit(1)

                # Register the schema as a node under the resolved name.
                store.store(name=name, node=schema)  # group=group,

            class _argparse_wrapper:
                """Parser stand-in whose ``parse_args`` returns ``parsed_args``.

                Mimics the ``ArgumentParser.parse_args()`` API so Hydra receives
                the already-amended namespace instead of re-parsing argv.
                """

                def __init__(self, arg_parser):
                    self.arg_parser = arg_parser
                    # Mirror the real parser's actions; presumably read by Hydra
                    # (e.g. for help output) — verify against _run_hydra.
                    self._actions = arg_parser._actions

                def parse_args(self, args=None, namespace=None):
                    # Arguments are ignored: the namespace was parsed above.
                    return parsed_args

            # _run_hydra() yields no usable result: it may run the task
            # function multiple times (--multirun).
            _run_hydra(
                args_parser=_argparse_wrapper(parser),
                task_function=task_function,
                config_path=config_path,
                config_name=config_name,
                strict=None,
            )
Example #4
0
        def decorated_main(
                config: Union[str, dict, list, tuple, DictConfig]) -> Any:
            """Resolve ``config``, rewrite ``sys.argv`` for Hydra and run the task.

            Args:
                config: One of
                    * a path to a YAML file: its directory becomes the config
                      path and its file name without ``.yaml`` the config name;
                    * a directory: ``base.yaml`` is used when present, otherwise
                      the alphabetically first ``*.yaml`` file (falling back to
                      the name ``"base"`` when the directory has none);
                    * any other string, or a dict/list/tuple/``DictConfig``,
                      which is passed to ``_save_config_to_tempdir`` to obtain
                      the config path and name.

            Besides regular Hydra overrides, ``sys.argv`` is scanned for extra
            flags:
                * ``LIST_PATTERN``/``SUMMARY_PATTERN``: print the entries of
                  ``output_dir`` (with log summaries) and exit;
                * ``REMOVE_EXIST_PATTERN``: before each run, clear existing
                  output folders that share the run directory's name;
                * ``JOBS_PATTERN``: use the joblib launcher with that many jobs.
            The last two flags are removed from ``sys.argv``.

            Exceptions raised by ``task_function`` are logged line by line and
            re-raised only when a single job is run.
            """

            def _strip_yaml(fname):
                # Drop only a trailing ".yaml"; ``str.replace`` would also
                # remove it from the middle of a name.
                return fname[:-len(".yaml")] if fname.endswith(".yaml") else fname

            def _print_experiments():
                # Print every entry of ``output_dir``. For entries that have
                # "<name>:<time>.log" files in ``log_dir``, also print each
                # log's timestamp, its line count and its [ERROR] messages.
                print("Output dir:", output_dir)
                all_logs = defaultdict(list)
                for log_name in os.listdir(log_dir):
                    stem = log_name.replace('.log', '')
                    if ':' not in stem:
                        continue  # not a "<name>:<time>.log" file
                    # Split on the last ':' so a name containing ':' still parses.
                    name, time_str = stem.rsplit(':', 1)
                    all_logs[name].append(
                        (time_str, os.path.join(log_dir, log_name)))
                for fname in sorted(os.listdir(output_dir)):
                    path = os.path.join(output_dir, fname)
                    # basics meta
                    print(
                        f" {fname}", f"({len(os.listdir(path))} files)"
                        if os.path.isdir(path) else "")
                    # show the log files info
                    for time_str, log_file in all_logs.get(fname, ()):
                        with open(log_file, 'r') as f:
                            lines = f.read().split('\n')
                        print(
                            f'  log {datetime.strptime(time_str, TIME_FMT)} ({len(lines)} lines)'
                        )
                        for line in lines:
                            if '[ERROR]' in line:
                                print(f'   {line.split("[ERROR]")[1]}')

            ### string
            if isinstance(config, string_types):
                # path to a config file
                if os.path.isfile(config):
                    config_name = _strip_yaml(os.path.basename(config))
                    config_path = os.path.dirname(config)
                # path to a directory
                elif os.path.isdir(config):
                    config_path = config
                    if os.path.exists(os.path.join(config_path, 'base.yaml')):
                        # Prefer base.yaml whenever it exists.
                        config_name = "base"
                    else:
                        yaml_files = sorted(
                            i for i in os.listdir(config_path)
                            if i.endswith('.yaml'))
                        # Otherwise take the first YAML file alphabetically;
                        # keep "base" as the default name when there is none.
                        config_name = (_strip_yaml(yaml_files[0])
                                       if yaml_files else "base")
                # YAML string
                else:
                    config_path, config_name = _save_config_to_tempdir(config)
            ### dictionary, tuple, list, DictConfig
            else:
                config_path, config_name = _save_config_to_tempdir(config)
            ### list all experiments command
            if any(LIST_PATTERN.match(a) or SUMMARY_PATTERN.match(a)
                   for a in sys.argv):
                _print_experiments()
                sys.exit()
            ### check if overrides provided (ignoring excluded keys)
            is_overrided = False
            for a in sys.argv:
                match = OVERRIDE_PATTERN.match(a)
                if match and not any(k in match.string for k in exclude_keys):
                    is_overrided = True
                    break
            ### formatting output dirs
            override_id = (r"${hydra.job.override_dirname}"
                           if is_overrided else r"default")
            ### check if enable remove exists experiment (consume the flag)
            remove_exists = False
            for i, a in enumerate(list(sys.argv)):
                if REMOVE_EXIST_PATTERN.match(a):
                    remove_exists = True
                    sys.argv.pop(i)
                    break
            ### parallel jobs provided (consume the flag)
            jobs = 1
            for i, a in enumerate(list(sys.argv)):
                match = JOBS_PATTERN.match(a)
                if match:
                    jobs = int(match.groups()[-1])
                    sys.argv.pop(i)
                    break
            if jobs > 1:
                _insert_argv(key="hydra/launcher",
                             value="joblib",
                             is_value_string=False)
                _insert_argv(key="hydra.launcher.n_jobs",
                             value=f"{jobs}",
                             is_value_string=False)
            ### running dirs
            _insert_argv(key="hydra.run.dir",
                         value=f"{output_dir}/{override_id}",
                         is_value_string=True)
            _insert_argv(key="hydra.sweep.dir",
                         value=f"{output_dir}/multirun/{HYDRA_TIME_FMT}",
                         is_value_string=True)
            _insert_argv(key="hydra.job_logging.handlers.file.filename",
                         value=f"{log_dir}/{override_id}:{HYDRA_TIME_FMT}.log",
                         is_value_string=True)
            _insert_argv(key="hydra.job.config.override_dirname.exclude_keys",
                         value=f"[{','.join([str(i) for i in exclude_keys])}]",
                         is_value_string=False)
            args = get_args_parser()
            config_path = _abspath(config_path)

            @functools.wraps(task_function)
            def _task_function(_cfg):
                # print out the running config
                cfg_text = '\n ----------- \n'
                cfg_text += OmegaConf.to_yaml(_cfg)[:-1]
                cfg_text += '\n -----------'
                logger.info(cfg_text)
                # remove existing folders with the same name as this run's dir
                if remove_exists:
                    run_dir = get_output_dir()
                    dir_base = os.path.dirname(run_dir)
                    dir_name = os.path.basename(run_dir)
                    for folder in get_all_folder(dir_base):
                        if dir_name == os.path.basename(folder):
                            clear_folder(folder, verbose=True)
                try:
                    # Return the task's result instead of discarding it.
                    return task_function(_cfg)
                except Exception:
                    _, value, tb = sys.exc_info()
                    for line in traceback.TracebackException(
                            type(value), value, tb,
                            limit=None).format(chain=False):
                        logger.error(line)
                    # Single job: propagate. Parallel jobs: log and keep going.
                    if jobs == 1:
                        raise

            # No return value from run_hydra(): it may run the task function
            # multiple times (--multirun).
            _run_hydra(
                args_parser=args,
                task_function=_task_function,
                config_path=config_path,
                config_name=config_name,
                strict=None,
            )
Example #5
0
 def decorated_main(
     config: Union[str, dict, list, tuple, DictConfig]) -> Any:
   """Resolve ``config``, rewrite ``sys.argv`` for Hydra and run the task.

   Args:
     config: One of
       * a path to a YAML file: its directory becomes the config path and
         its file name without ``.yaml`` the config name;
       * a directory: ``base.yaml`` is used when present, otherwise the
         alphabetically first ``*.yaml`` file (falling back to the name
         ``"base"`` when the directory has none);
       * any other string, or a dict/list/tuple/``DictConfig``, which is
         passed to ``_save_config_to_tempdir`` to obtain the path and name.

   ``sys.argv`` is also scanned for extra flags. ``LIST_PATTERN`` prints the
   entries of ``output_dir`` and exits. ``JOBS_PATTERN`` selects the joblib
   launcher with that many jobs; that flag is removed from ``sys.argv``.
   """

   def _strip_yaml(fname):
     # Drop only a trailing ".yaml"; ``str.replace`` would also remove it
     # from the middle of a name.
     return fname[:-len(".yaml")] if fname.endswith(".yaml") else fname

   ### string
   if isinstance(config, string_types):
     # path to a config file
     if os.path.isfile(config):
       config_name = _strip_yaml(os.path.basename(config))
       config_path = os.path.dirname(config)
     # path to a directory
     elif os.path.isdir(config):
       config_path = config
       if os.path.exists(os.path.join(config_path, 'base.yaml')):
         # Prefer base.yaml whenever it exists.
         config_name = "base"
       else:
         yaml_files = sorted(
             i for i in os.listdir(config_path) if i.endswith('.yaml'))
         # Otherwise take the first YAML file alphabetically; keep "base"
         # as the default name when there is none.
         config_name = _strip_yaml(yaml_files[0]) if yaml_files else "base"
     # YAML string
     else:
       config_path, config_name = _save_config_to_tempdir(config)
   ### dictionary, tuple, list, DictConfig
   else:
     config_path, config_name = _save_config_to_tempdir(config)
   ### list all experiments command
   if any(LIST_PATTERN.match(a) for a in sys.argv):
     print("Output dir:", output_dir)
     for fname in sorted(os.listdir(output_dir)):
       path = os.path.join(output_dir, fname)
       print(
           f" {fname}", f"({len(os.listdir(path))} files)"
           if os.path.isdir(path) else "")
     sys.exit()
   ### check if overrides provided (ignoring excluded keys)
   is_overrided = False
   for a in sys.argv:
     match = OVERRIDE_PATTERN.match(a)
     if match and not any(k in match.string for k in exclude_keys):
       is_overrided = True
       break
   ### formatting output dirs
   time_fmt = r"${now:%j_%H%M%S}"
   override_id = (r"${hydra.job.override_dirname}"
                  if is_overrided else r"default")
   ### parallel jobs provided (consume the flag)
   jobs = 1
   for i, a in enumerate(list(sys.argv)):
     match = JOBS_PATTERN.match(a)
     if match:
       jobs = int(match.groups()[-1])
       sys.argv.pop(i)
       break
   if jobs > 1:
     _insert_argv(key="hydra/launcher",
                  value="joblib",
                  is_value_string=False)
     _insert_argv(key="hydra.launcher.n_jobs",
                  value=f"{jobs}",
                  is_value_string=False)
   ### running dirs
   _insert_argv(key="hydra.run.dir",
                value=f"{output_dir}/{override_id}",
                is_value_string=True)
   _insert_argv(key="hydra.sweep.dir",
                value=f"{output_dir}/multirun/{time_fmt}",
                is_value_string=True)
   _insert_argv(key="hydra.job_logging.handlers.file.filename",
                value=f"{log_dir}/{override_id}:{time_fmt}.log",
                is_value_string=True)
   _insert_argv(key="hydra.job.config.override_dirname.exclude_keys",
                value=f"[{','.join([str(i) for i in exclude_keys])}]",
                is_value_string=False)
   args = get_args_parser()
   config_path = _abspath(config_path)
   # No return value from run_hydra(): it may run the task function
   # multiple times (--multirun).
   _run_hydra(
       args_parser=args,
       task_function=task_function,
       config_path=config_path,
       config_name=config_name,
       strict=None,
   )
Example #6
0
    def decorated_main(
        config: Union[str, dict, list, tuple, DictConfig]) -> Any:
      """Resolve ``config``, rewrite ``sys.argv`` for Hydra and run the task.

      Args:
        config: One of
          * a path to a YAML file: its directory becomes the config path and
            its file name without ``.yaml`` the config name;
          * a directory: ``base.yaml`` is used when present, otherwise the
            alphabetically first ``*.yaml`` file (falling back to the name
            ``"base"`` when the directory has none);
          * a string for which ``YAML_REGEX`` finds more than one match, or a
            dict/list/tuple/``DictConfig``: passed to
            ``_save_config_to_tempdir`` to obtain the path and name.

      Raises:
        ValueError: if ``config`` is a string that fits none of the above.
      """

      def _strip_yaml(fname):
        # Drop only a trailing ".yaml"; ``str.replace`` would also remove it
        # from the middle of a name.
        return fname[:-len(".yaml")] if fname.endswith(".yaml") else fname

      ## string
      if isinstance(config, string_types):
        if os.path.isfile(config):
          config_name = _strip_yaml(os.path.basename(config))
          config_path = os.path.dirname(config)
        elif os.path.isdir(config):
          config_path = config
          if os.path.exists(os.path.join(config_path, 'base.yaml')):
            # Prefer base.yaml whenever it exists.
            config_name = "base"
          else:
            yaml_files = sorted(
                i for i in os.listdir(config_path) if i.endswith('.yaml'))
            # Otherwise take the first YAML file alphabetically; keep "base"
            # as the default name when there is none.
            config_name = _strip_yaml(yaml_files[0]) if yaml_files else "base"
        elif len(YAML_REGEX.findall(config)) > 1:
          config_path, config_name = _save_config_to_tempdir(config)
        else:
          raise ValueError(
              f"No support for string config with format: {config}")
      ## dictionary, tuple, list, DictConfig
      else:
        config_path, config_name = _save_config_to_tempdir(config)
      # check if overrides provided (ignoring excluded keys)
      is_overrided = False
      for a in sys.argv:
        match = OVERRIDE_PATTERN.match(a)
        if match and not any(k in match.string for k in exclude_keys):
          is_overrided = True
          break
      # formatting output dirs
      time_fmt = r"${now:%j_%H%M%S}"
      override_id = (r"${hydra.job.override_dirname}"
                     if is_overrided else r"default")
      # jobs provided
      # NOTE(review): argv is joined with '~' and split back, so when a jobs
      # flag is found any argument that itself contains '~' gets split apart.
      # Confirm argv never contains '~' (and how JOBS_PATTERN treats the
      # separators), or parse per-argument instead.
      jobs = 1
      text = '~'.join(sys.argv)
      match = JOBS_PATTERN.search(text)
      if match:
        jobs = int(match.groups()[0])
        text = re.sub(JOBS_PATTERN, "~", text)
        sys.argv = text.split("~")
      if jobs > 1:
        _insert_argv(key="hydra/launcher",
                     value="joblib",
                     is_value_string=False)
        _insert_argv(key="hydra.launcher.n_jobs",
                     value=f"{jobs}",
                     is_value_string=False)
      # running dirs
      _insert_argv(key="hydra.run.dir",
                   value=f"{output_dir}/{override_id}",
                   is_value_string=True)
      _insert_argv(key="hydra.sweep.dir",
                   value=f"{output_dir}/multirun/{time_fmt}",
                   is_value_string=True)
      _insert_argv(key="hydra.job_logging.handlers.file.filename",
                   value=f"{log_dir}/{override_id}:{time_fmt}.log",
                   is_value_string=True)
      _insert_argv(key="hydra.job.config.override_dirname.exclude_keys",
                   value=f"[{','.join([str(i) for i in exclude_keys])}]",
                   is_value_string=False)
      args = get_args_parser()
      config_path = _abspath(config_path)
      # No return value from run_hydra(): it may run the task function
      # multiple times (--multirun).
      _run_hydra(
          args_parser=args,
          task_function=task_function,
          config_path=config_path,
          config_name=config_name,
          strict=None,
      )