Exemple #1
0
        def wrapper(cfg_passthrough: Optional[DictConfig] = None) -> Any:
            """Run ``task_function`` directly, or hand control to Hydra.

            If ``cfg_passthrough`` is given, Hydra is bypassed: ``task_function``
            is called with it and its result is returned. Otherwise the Hydra
            argument parser is built and ``_run_hydra`` is invoked with
            ``config_name=class_name`` and no config path; ``None`` is returned.
            """
            if cfg_passthrough is not None:
                # A ready-made config was supplied: skip Hydra entirely.
                return task_function(cfg_passthrough)

            parser = get_args_parser()
            # _run_hydra() gives no useful return value: with --multirun it may
            # call task_function several times.
            _run_hydra(
                args_parser=parser,
                task_function=task_function,
                config_path=None,
                config_name=class_name,
                strict=None,
            )
Exemple #2
0
        def wrapper(cfg_passthrough: Optional[DictConfig] = None) -> Any:
            """Run ``task_function`` directly, or hand control to Hydra.

            With ``cfg_passthrough`` given, ``task_function`` is called on it and
            its result is returned. Otherwise the command line is parsed once,
            ``hydra.run.dir=.`` and ``hydra.job_logging.root.handlers=null`` are
            appended to the parsed overrides, and ``_run_hydra`` is started with a
            parser stand-in whose ``parse_args()`` returns that pre-parsed
            namespace. ``None`` is returned on that path.
            """
            if cfg_passthrough is not None:
                return task_function(cfg_passthrough)

            parser = get_args_parser()
            # Parse once up front so the override list can be extended in place;
            # the stand-in below hands this very namespace back to Hydra.
            cli_args = parser.parse_args()
            cli_overrides = cli_args.overrides  # type: list
            cli_overrides.append("hydra.run.dir=.")
            cli_overrides.append('hydra.job_logging.root.handlers=null')

            class _ParsedArgsParser:
                """Mimics the ArgumentParser.parse_args() API, returning the
                pre-parsed (and amended) namespace."""

                def __init__(self, arg_parser):
                    self.arg_parser = arg_parser
                    # Mirrors the real parser's actions; presumably read by
                    # Hydra (e.g. for help output) -- TODO confirm.
                    self._actions = arg_parser._actions

                def parse_args(self, args=None, namespace=None):
                    return cli_args

            # Nothing to return: with --multirun task_function may run many times.
            _run_hydra(
                args_parser=_ParsedArgsParser(parser),
                task_function=task_function,
                config_path=config_path,
                config_name=config_name,
                strict=None,
            )
Exemple #3
0
        def wrapper(cfg_passthrough: Optional[DictConfig] = None) -> Any:
            """Run ``task_function`` directly, or hand control to Hydra.

            When ``cfg_passthrough`` is given, ``task_function`` is called on it
            and its result is returned; Hydra is not involved.

            Otherwise the command line is parsed once and these overrides are
            appended to the parsed list:

            * ``hydra.output_subdir=null`` -- no ``.hydra`` subdirectory;
            * ``hydra/job_logging=stdout`` -- log to stdout only, no log file;
            * ``hydra.run.dir=.`` -- kept only for ExpManager compatibility.

            If ``schema`` is set, it is stored in Hydra's ``ConfigStore`` under
            the config name: the file-name part of ``--config-name`` when given,
            otherwise ``config_name``. A ``--config-name`` that also contains a
            directory is rejected (message on stderr, exit status 1), since that
            would disable schema validation.

            Finally ``_run_hydra`` is started with a parser stand-in returning the
            pre-parsed namespace; ``None`` is returned on this path.
            """
            if cfg_passthrough is not None:
                return task_function(cfg_passthrough)

            parser = get_args_parser()
            cli_args = parser.parse_args()  # type: argparse.Namespace
            # Overrides in dot-string form, extended in place so that the
            # namespace handed to Hydra carries the additions.
            cli_overrides = cli_args.overrides  # type: list
            for extra in (
                    # https://hydra.cc/docs/tutorials/basic/running_your_app/working_directory
                    "hydra.output_subdir=null",
                    # https://hydra.cc/docs/configure_hydra/logging
                    "hydra/job_logging=stdout",
                    # run.dir is set ONLY for ExpManager "compatibility" - to be removed.
                    "hydra.run.dir=.",
            ):
                cli_overrides.append(extra)

            if schema is not None:
                store = ConfigStore.instance()
                if cli_args.config_name is None:
                    store_name = config_name
                else:
                    dir_part, store_name = os.path.split(cli_args.config_name)
                    # A directory in --config-name would disable schema validation.
                    if dir_part != '':
                        sys.stderr.write(
                            "ERROR Cannot set config file path using `--config-name` when "
                            "using schema. Please set path using `--config-path` and file name using "
                            "`--config-name` separately.\n")
                        sys.exit(1)
                # Register the schema as a node under that name (no group).
                store.store(name=store_name, node=schema)

            class _ParsedArgsParser:
                """Mimics the ArgumentParser.parse_args() API, returning the
                pre-parsed (and amended) namespace."""

                def __init__(self, arg_parser):
                    self.arg_parser = arg_parser
                    # Mirrors the real parser's actions; presumably read by
                    # Hydra (e.g. for help output) -- TODO confirm.
                    self._actions = arg_parser._actions

                def parse_args(self, args=None, namespace=None):
                    return cli_args

            # Nothing to return: with --multirun task_function may run many times.
            _run_hydra(
                args_parser=_ParsedArgsParser(parser),
                task_function=task_function,
                config_path=config_path,
                config_name=config_name,
                strict=None,
            )
Exemple #4
0
    def run(self, overrides=None, ncpu=None, **configs):
        r"""Run the experiment through Hydra, as a single run or a sweep.

    Arguments:
      overrides: Hydra overrides, converted with ``_overrides`` (the example
        below uses the dict form). ``None`` (the default) means no overrides.
      ncpu: number of processes for the ``ParallelLauncher`` used in multirun
        mode; defaults to ``self.ncpu``. A command-line argument containing
        ``ncpu`` (``--ncpu=N`` or ``--ncpu N``) takes precedence and is
        removed from ``sys.argv``.
      **configs: additional overrides, converted like ``overrides``.

    Configurations are always composed with ``strict=False``.

    If ``--reset``, ``--clear`` or ``--clean`` is on the command line and
    experiments already exist under ``self._save_path``, the user is asked
    whether to delete them (anything but ``n`` confirms). All such flags are
    removed from ``sys.argv``.

    A multirun is assumed when an override, or a ``key=value`` command-line
    argument, contains a comma; the single-run path then raises
    ``RuntimeError`` asking for ``-m``.

    ``Hydra.run``/``Hydra.multirun`` are temporarily monkey-patched and are
    always restored before returning or propagating an exception.

    Returns:
      ``self``, after ``self.summary()`` has been called.

    Example:
      exp = SisuaExperimenter(ncpu=1)
      exp.run(
          overrides={
              'model': ['sisua', 'dca', 'vae'],
              'dataset.name': ['cortex', 'pbmc8kly'],
              'train.verbose': 0,
              'train.epochs': 2,
              'train': ['adam'],
          })
    """
        # avoid a shared mutable default argument
        if overrides is None:
            overrides = []
        overrides = _overrides(overrides) + _overrides(configs)
        strict = False
        command = ' '.join(sys.argv)
        # parse ncpu: a command-line value takes precedence
        if ncpu is None:
            ncpu = self.ncpu
        ncpu = int(ncpu)
        for idx, arg in enumerate(list(sys.argv)):
            if 'ncpu' in arg:
                if '=' in arg:
                    ncpu = int(arg.split('=')[-1])
                    sys.argv.pop(idx)
                else:
                    # the value is the following argument; drop both
                    ncpu = int(sys.argv[idx + 1])
                    sys.argv.pop(idx)
                    sys.argv.pop(idx)
                break
        # check reset: prompt once, then strip every reset flag from argv
        # (popping by enumerate index inside a loop would remove the wrong
        # argument as soon as a second flag is present)
        reset_flags = ('--reset', '--clear', '--clean')
        if any(arg in reset_flags for arg in sys.argv):
            configs_filter = lambda f: 'configs' != f.split('/')[-1]
            if len(
                    get_all_files(self._save_path,
                                  filter_func=configs_filter)) > 0:
                old_exps = '\n'.join([
                    " - %s" % i for i in os.listdir(self._save_path)
                    if configs_filter(i)
                ])
                inp = input("<Enter> to clear all exists experiments:"
                            "\n%s\n'n' to cancel, otherwise continue:" %
                            old_exps)
                if inp.strip().lower() != 'n':
                    clean_folder(self._save_path,
                                 filter=configs_filter,
                                 verbose=True)
            sys.argv[:] = [arg for arg in sys.argv if arg not in reset_flags]
        # check multirun
        is_multirun = any(',' in ovr for ovr in overrides) or \
          any(',' in arg and '=' in arg for arg in sys.argv)
        # write history
        self.write_history(command, "overrides: %s" % str(overrides),
                           "strict: %s" % str(strict), "ncpu: %d" % ncpu,
                           "multirun: %s" % str(is_multirun))
        # generate app help (currently unused, see the TODO below)
        hlp = '\n\n'.join([
            "%s - %s" % (str(key), ', '.join(sorted(as_tuple(val, t=str))))
            for key, val in dict(self.args_help).items()
        ])

        # Replacement for Hydra.run; `self` here is the Hydra instance.
        def _run(self, config_file, task_function, overrides):
            if is_multirun:
                raise RuntimeError(
                    "Performing single run with multiple overrides in hydra "
                    "(use '-m' for multirun): %s" % str(overrides))
            cfg = self.compose_config(config_file=config_file,
                                      overrides=overrides,
                                      strict=strict,
                                      with_log_configuration=True)
            HydraConfig().set_config(cfg)
            return run_job(
                config=cfg,
                task_function=task_function,
                job_dir_key="hydra.run.dir",
                job_subdir_key=None,
            )

        # Replacement for Hydra.multirun using a ParallelLauncher with `ncpu`.
        def _multirun(self, config_file, task_function, overrides):
            # Initial config is loaded without strict (individual job configs may have strict).
            from hydra._internal.plugins import Plugins
            cfg = self.compose_config(config_file=config_file,
                                      overrides=overrides,
                                      strict=strict,
                                      with_log_configuration=True)
            HydraConfig().set_config(cfg)
            sweeper = Plugins.instantiate_sweeper(
                config=cfg,
                config_loader=self.config_loader,
                task_function=task_function)
            # override launcher for using multiprocessing
            sweeper.launcher = ParallelLauncher(ncpu=ncpu)
            sweeper.launcher.setup(config=cfg,
                                   config_loader=self.config_loader,
                                   task_function=task_function)
            return sweeper.sweep(arguments=cfg.hydra.overrides.task)

        old_run, old_multirun = Hydra.run, Hydra.multirun
        Hydra.run = _run
        Hydra.multirun = _multirun

        try:
            # append the new override
            if len(overrides) > 0:
                sys.argv += overrides
            # help for arguments
            if '--help' in sys.argv:
                # sys.argv.append("hydra.help.header='**** %s ****'" %
                #                 self.__class__.__name__)
                # sys.argv.append("hydra.help.template=%s" % (_APP_HELP % hlp))
                # TODO : fix bug here
                pass
            # append the hydra log path
            job_fmt = "/${now:%d%b%y_%H%M%S}"
            sys.argv.insert(
                1, "hydra.run.dir=%s" % self.get_hydra_path() + job_fmt)
            sys.argv.insert(
                1, "hydra.sweep.dir=%s" % self.get_hydra_path() + job_fmt)
            sys.argv.insert(1, "hydra.sweep.subdir=${hydra.job.id}")
            # sys.argv.append(r"hydra.job_logging.formatters.simple.format=" +
            #                 r"[\%(asctime)s][\%(name)s][\%(levelname)s] - \%(message)s")
            args_parser = get_args_parser()
            run_hydra(
                args_parser=args_parser,
                task_function=self._run,
                config_path=self.config_path,
                strict=strict,
            )
        except KeyboardInterrupt:
            sys.exit(-1)
        except SystemExit:
            # presumably raised by Hydra/argparse on exit paths (e.g. help);
            # deliberately swallowed so the summary below still runs
            pass
        finally:
            # always undo the monkey-patch, even when the run raised
            Hydra.run = old_run
            Hydra.multirun = old_multirun
        # update the summary
        self.summary()
        return self
Exemple #5
0
        def decorated_main(
                config: Union[str, dict, list, tuple, DictConfig]) -> Any:
            """Resolve ``config`` to a Hydra config location, then run
            ``task_function`` through Hydra.

            ``config`` may be:

            * a path to a YAML file -- its directory and base name (``.yaml``
              removed) are used;
            * a path to a directory -- ``base.yaml`` is used if present,
              otherwise the alphabetically first ``*.yaml`` file in it;
            * any other string, or a dict/list/tuple/DictConfig -- handed to
              ``_save_config_to_tempdir``.

            ``sys.argv`` is inspected and edited in place:

            * an argument matching ``LIST_PATTERN`` or ``SUMMARY_PATTERN`` prints
              the entries of ``output_dir`` with their log files (line counts and
              ``[ERROR]`` lines), then exits;
            * an argument matching ``REMOVE_EXIST_PATTERN`` is removed; each job
              then clears folders returned by ``get_all_folder`` under the parent
              of its output dir that share the output dir's name;
            * an argument matching ``JOBS_PATTERN`` is removed; more than one job
              selects the joblib launcher with that ``n_jobs``.

            Run/sweep directories, the log-file name and the override-dirname
            exclusions are injected with ``_insert_argv`` before ``_run_hydra``
            is called. Each job logs its config; an exception raised by the task
            is logged and re-raised only when a single job is running.

            Raises:
              ValueError: ``config`` is a directory that holds no ``*.yaml`` file.
            """
            ### string
            if isinstance(config, string_types):
                # path to a config file
                if os.path.isfile(config):
                    config_name = os.path.basename(config).replace(".yaml", "")
                    config_path = os.path.dirname(config)
                # path to a directory: prefer base.yaml, else the first *.yaml
                elif os.path.isdir(config):
                    config_path = config
                    if os.path.exists(os.path.join(config_path, 'base.yaml')):
                        config_name = "base"  # default name
                    else:
                        yaml_files = sorted(i for i in os.listdir(config_path)
                                            if i.endswith('.yaml'))
                        if not yaml_files:
                            raise ValueError(
                                f"No .yaml config file found in directory: {config_path}"
                            )
                        config_name = yaml_files[0][:-len(".yaml")]
                # YAML string
                else:
                    config_path, config_name = _save_config_to_tempdir(config)
            ### dictionary, tuple, list, DictConfig
            else:
                config_path, config_name = _save_config_to_tempdir(config)
            ### list all experiments command
            for a in sys.argv:
                if LIST_PATTERN.match(a) or SUMMARY_PATTERN.match(a):
                    print("Output dir:", output_dir)
                    # group log files "<name>:<time>.log" by experiment name
                    all_logs = defaultdict(list)
                    for i in os.listdir(log_dir):
                        name, time_str = i.replace('.log', '').split(':')
                        all_logs[name].append(
                            (time_str, os.path.join(log_dir, i)))
                    for fname in sorted(os.listdir(output_dir)):
                        path = os.path.join(output_dir, fname)
                        # basics meta
                        print(
                            f" {fname}", f"({len(os.listdir(path))} files)"
                            if os.path.isdir(path) else "")
                        # show the log files info
                        if fname in all_logs:
                            for time_str, log_file in all_logs[fname]:
                                with open(log_file, 'r') as f:
                                    log_data = f.read()
                                    lines = log_data.split('\n')
                                    n = len(lines)
                                    print(
                                        f'  log {datetime.strptime(time_str, TIME_FMT)} ({n} lines)'
                                    )
                                    for e in [
                                            l for l in lines if '[ERROR]' in l
                                    ]:
                                        print(f'   {e.split("[ERROR]")[1]}')
                    # sys.exit rather than the site-module `exit` helper,
                    # which is not guaranteed to exist (e.g. python -S)
                    sys.exit()
            ### check if overrides provided
            is_overrided = False
            for a in sys.argv:
                match = OVERRIDE_PATTERN.match(a)
                if match and not any(k in match.string for k in exclude_keys):
                    is_overrided = True
            ### formatting output dirs
            if is_overrided:
                override_id = r"${hydra.job.override_dirname}"
            else:
                override_id = r"default"
            ### check if enable remove exists experiment
            remove_exists = False
            for i, a in enumerate(list(sys.argv)):
                match = REMOVE_EXIST_PATTERN.match(a)
                if match:
                    remove_exists = True
                    sys.argv.pop(i)
                    break
            ### parallel jobs provided
            jobs = 1
            for i, a in enumerate(list(sys.argv)):
                match = JOBS_PATTERN.match(a)
                if match:
                    jobs = int(match.groups()[-1])
                    sys.argv.pop(i)
                    break
            if jobs > 1:
                _insert_argv(key="hydra/launcher",
                             value="joblib",
                             is_value_string=False)
                _insert_argv(key="hydra.launcher.n_jobs",
                             value=f"{jobs}",
                             is_value_string=False)
            ### running dirs
            _insert_argv(key="hydra.run.dir",
                         value=f"{output_dir}/{override_id}",
                         is_value_string=True)
            _insert_argv(key="hydra.sweep.dir",
                         value=f"{output_dir}/multirun/{HYDRA_TIME_FMT}",
                         is_value_string=True)
            _insert_argv(key="hydra.job_logging.handlers.file.filename",
                         value=f"{log_dir}/{override_id}:{HYDRA_TIME_FMT}.log",
                         is_value_string=True)
            _insert_argv(key="hydra.job.config.override_dirname.exclude_keys",
                         value=f"[{','.join([str(i) for i in exclude_keys])}]",
                         is_value_string=False)
            args = get_args_parser()
            config_path = _abspath(config_path)

            ## run hydra
            @functools.wraps(task_function)
            def _task_function(_cfg):
                # print out the running config
                cfg_text = '\n ----------- \n'
                cfg_text += OmegaConf.to_yaml(_cfg)[:-1]
                cfg_text += '\n -----------'
                logger.info(cfg_text)
                # remove the existing outputs of this job
                if remove_exists:
                    job_dir = get_output_dir()
                    dir_base = os.path.dirname(job_dir)
                    dir_name = os.path.basename(job_dir)
                    for folder in get_all_folder(dir_base):
                        if dir_name == os.path.basename(folder):
                            clear_folder(folder, verbose=True)
                # log any failure; with parallel jobs keep the sweep going,
                # otherwise propagate the original exception
                try:
                    task_function(_cfg)
                except Exception:
                    _, value, tb = sys.exc_info()
                    for line in traceback.TracebackException(
                            type(value), value, tb,
                            limit=None).format(chain=None):
                        logger.error(line)
                    if jobs == 1:
                        raise

            # no return value from run_hydra() as it may sometime actually run the task_function
            # multiple times (--multirun)
            _run_hydra(
                args_parser=args,
                task_function=_task_function,
                config_path=config_path,
                config_name=config_name,
                strict=None,
            )
Exemple #6
0
 def decorated_main(
     config: Union[str, dict, list, tuple, DictConfig]) -> Any:
   """Resolve ``config`` to a Hydra config location and run ``task_function``.

   ``config`` may be a path to a YAML file (its directory and base name are
   used), a path to a directory (``base.yaml`` if present, otherwise the
   alphabetically first ``*.yaml`` file), or any other string / dict / list /
   tuple / DictConfig, which is handed to ``_save_config_to_tempdir``.

   ``sys.argv`` is edited in place: an argument matching ``LIST_PATTERN``
   prints the entries of ``output_dir`` and exits; an argument matching
   ``JOBS_PATTERN`` is removed and, for more than one job, selects the joblib
   launcher. Run/sweep directories, the log-file name and the override-dirname
   exclusions are then injected with ``_insert_argv`` before ``_run_hydra``
   is called.

   Raises:
     ValueError: ``config`` is a directory that holds no ``*.yaml`` file.
   """
   ### string
   if isinstance(config, string_types):
     # path to a config file
     if os.path.isfile(config):
       config_name = os.path.basename(config).replace(".yaml", "")
       config_path = os.path.dirname(config)
     # path to a directory: prefer base.yaml, else the first *.yaml file
     elif os.path.isdir(config):
       config_path = config
       if os.path.exists(os.path.join(config_path, 'base.yaml')):
         config_name = "base"  # default name
       else:
         yaml_files = sorted(
             i for i in os.listdir(config_path) if i.endswith('.yaml'))
         if not yaml_files:
           raise ValueError(
               f"No .yaml config file found in directory: {config_path}")
         config_name = yaml_files[0][:-len(".yaml")]
     # YAML string
     else:
       config_path, config_name = _save_config_to_tempdir(config)
   ### dictionary, tuple, list, DictConfig
   else:
     config_path, config_name = _save_config_to_tempdir(config)
   ### list all experiments command
   for a in sys.argv:
     if LIST_PATTERN.match(a):
       print("Output dir:", output_dir)
       for fname in sorted(os.listdir(output_dir)):
         path = os.path.join(output_dir, fname)
         print(
             f" {fname}", f"({len(os.listdir(path))} files)"
             if os.path.isdir(path) else "")
       # sys.exit rather than the site-module `exit` helper
       sys.exit()
   ### check if overrides provided
   is_overrided = False
   for a in sys.argv:
     match = OVERRIDE_PATTERN.match(a)
     if match and not any(k in match.string for k in exclude_keys):
       is_overrided = True
   ### formatting output dirs
   time_fmt = r"${now:%j_%H%M%S}"
   if is_overrided:
     override_id = r"${hydra.job.override_dirname}"
   else:
     override_id = r"default"
   ### parallel jobs provided
   jobs = 1
   for i, a in enumerate(list(sys.argv)):
     match = JOBS_PATTERN.match(a)
     if match:
       jobs = int(match.groups()[-1])
       sys.argv.pop(i)
       break
   if jobs > 1:
     _insert_argv(key="hydra/launcher",
                  value="joblib",
                  is_value_string=False)
     _insert_argv(key="hydra.launcher.n_jobs",
                  value=f"{jobs}",
                  is_value_string=False)
   ### running dirs
   _insert_argv(key="hydra.run.dir",
                value=f"{output_dir}/{override_id}",
                is_value_string=True)
   _insert_argv(key="hydra.sweep.dir",
                value=f"{output_dir}/multirun/{time_fmt}",
                is_value_string=True)
   _insert_argv(key="hydra.job_logging.handlers.file.filename",
                value=f"{log_dir}/{override_id}:{time_fmt}.log",
                is_value_string=True)
   _insert_argv(key="hydra.job.config.override_dirname.exclude_keys",
                value=f"[{','.join([str(i) for i in exclude_keys])}]",
                is_value_string=False)
   args = get_args_parser()
   config_path = _abspath(config_path)
   ## run hydra
   # no return value from run_hydra() as it may sometime actually run the task_function
   # multiple times (--multirun)
   _run_hydra(
       args_parser=args,
       task_function=task_function,
       config_path=config_path,
       config_name=config_name,
       strict=None,
   )
Exemple #7
0
    def decorated_main(
        config: Union[str, dict, list, tuple, DictConfig]) -> Any:
      """Resolve ``config`` to a Hydra config location and run ``task_function``.

      ``config`` may be a path to a YAML file (its directory and base name are
      used), a path to a directory (``base.yaml`` if present, otherwise the
      alphabetically first ``*.yaml`` file), a string in which
      ``YAML_REGEX`` finds more than one match, or a dict / list / tuple /
      DictConfig; the last two are handed to ``_save_config_to_tempdir``.

      A jobs flag matching ``JOBS_PATTERN`` is stripped from ``sys.argv``;
      more than one job selects the joblib launcher. Run/sweep directories,
      the log-file name and the override-dirname exclusions are injected with
      ``_insert_argv`` before ``_run_hydra`` is called.

      Raises:
        ValueError: ``config`` is an unsupported string, or a directory that
          holds no ``*.yaml`` file.
      """
      ## string
      if isinstance(config, string_types):
        if os.path.isfile(config):
          config_name = os.path.basename(config).replace(".yaml", "")
          config_path = os.path.dirname(config)
        elif os.path.isdir(config):
          # directory: prefer base.yaml, else the first *.yaml file
          config_path = config
          if os.path.exists(os.path.join(config_path, 'base.yaml')):
            config_name = "base"  # default name
          else:
            yaml_files = sorted(
                i for i in os.listdir(config_path) if i.endswith('.yaml'))
            if not yaml_files:
              raise ValueError(
                  f"No .yaml config file found in directory: {config_path}")
            config_name = yaml_files[0][:-len(".yaml")]
        elif len(YAML_REGEX.findall(config)) > 1:
          config_path, config_name = _save_config_to_tempdir(config)
        else:
          raise ValueError(
              f"No support for string config with format: {config}")
      ## dictionary, tuple, list, DictConfig
      else:
        config_path, config_name = _save_config_to_tempdir(config)
      # check if overrides provided
      is_overrided = False
      for a in sys.argv:
        match = OVERRIDE_PATTERN.match(a)
        if match and not any(k in match.string for k in exclude_keys):
          is_overrided = True
      # formatting output dirs
      time_fmt = r"${now:%j_%H%M%S}"
      if is_overrided:
        override_id = r"${hydra.job.override_dirname}"
      else:
        override_id = r"default"
      # jobs provided
      # NOTE(review): argv is joined and re-split on '~', so when a jobs flag is
      # present any argument that itself contains '~' (e.g. a home-relative
      # path) is split into several arguments -- confirm the expected form of
      # JOBS_PATTERN and consider matching per argument instead.
      jobs = 1
      text = '~'.join(sys.argv)
      match = JOBS_PATTERN.search(text)
      if match:
        jobs = int(match.groups()[0])
        text = re.sub(JOBS_PATTERN, "~", text)
        sys.argv = text.split("~")
      if jobs > 1:
        _insert_argv(key="hydra/launcher",
                     value="joblib",
                     is_value_string=False)
        _insert_argv(key="hydra.launcher.n_jobs",
                     value=f"{jobs}",
                     is_value_string=False)
      # running dirs
      _insert_argv(key="hydra.run.dir",
                   value=f"{output_dir}/{override_id}",
                   is_value_string=True)
      _insert_argv(key="hydra.sweep.dir",
                   value=f"{output_dir}/multirun/{time_fmt}",
                   is_value_string=True)
      _insert_argv(key="hydra.job_logging.handlers.file.filename",
                   value=f"{log_dir}/{override_id}:{time_fmt}.log",
                   is_value_string=True)
      _insert_argv(key="hydra.job.config.override_dirname.exclude_keys",
                   value=f"[{','.join([str(i) for i in exclude_keys])}]",
                   is_value_string=False)
      args = get_args_parser()
      config_path = _abspath(config_path)

      ## run hydra
      # no return value from run_hydra() as it may sometime actually run the task_function
      # multiple times (--multirun)
      _run_hydra(
          args_parser=args,
          task_function=task_function,
          config_path=config_path,
          config_name=config_name,
          strict=None,
      )