예제 #1
0
    def __init__(self, fn):
        """
        Load a job description and prepare the list of files to run.

        The result is stored on ``self.run_yamls``; ``__init__`` itself
        returns nothing. Three cases are handled:

        * ``fn`` ends with ``.py``: treated as a python command.
          ``self.backend`` is set to ``"python"`` and ``fn`` is the only
          entry to run.
        * the loaded config has ``base_dir`` equal to ``None``: treated as a
          single-file job config and ``fn`` is the only entry to run.
        * otherwise: every stage listed in ``run_task`` is resolved under
          ``projects/<project_dir>``, and the configs produced by
          ``self._overwrite_task`` are handed to ``self._save_configs``.

        Args:
            fn: path of the job file, either a ``.py`` command or a config
                file loadable by ``recursive_config``.
        """
        if fn.endswith(".py"):
            # a python command.
            self.backend = "python"
            self.run_yamls = [fn]
            return

        job_config = recursive_config(fn)
        if job_config.base_dir is None:  # single file job config.
            self.run_yamls = [fn]
            return

        self.project_dir = os.path.join("projects", job_config.project_dir)
        self.run_dir = os.path.join("runs", job_config.project_dir)

        # Always define run_yamls so a constructed instance never lacks the
        # attribute, even when the job config specifies no run_task.
        self.run_yamls = []
        if job_config.run_task is not None:
            for stage in job_config.run_task:
                if OmegaConf.is_list(stage):
                    # a stage given as a list holds tasks that run in
                    # parallel; keep them grouped as a nested list.
                    self.run_yamls.append([
                        os.path.join(self.project_dir, task_file)
                        for task_file in stage
                    ])
                else:
                    self.run_yamls.append(
                        os.path.join(self.project_dir, stage))

        configs_to_save = self._overwrite_task(job_config)
        self._save_configs(configs_to_save)
예제 #2
0
def main(args):
    """Tokenize the captions if needed, then shard the tokenized file.

    The dataset section of the config at ``args.config`` supplies
    ``caption_pkl_path`` and ``bert_name``. The output path is the caption
    path with its extension replaced by ``.<bert_name>.pkl``. ``tokenize``
    is skipped when that file already exists; ``sharding`` always runs.

    Args:
        args: object exposing a ``config`` attribute with the config path.
    """
    dataset_cfg = recursive_config(args.config).dataset

    stem, _ = os.path.splitext(dataset_cfg.caption_pkl_path)
    tokenized_path = ".".join((stem, dataset_cfg.bert_name, "pkl"))

    already_tokenized = os.path.isfile(tokenized_path)
    if not already_tokenized:
        tokenize(dataset_cfg, tokenized_path)
    sharding(dataset_cfg, tokenized_path)
예제 #3
0
    def __init__(self, yaml_file):
        """
        Record the job file and derive the job key from its config.

        The key defaults to ``"local"``. When the loaded config has a
        ``task_type`` that is not ``None``, the key is the part of it
        before the first ``_``.

        Args:
            yaml_file: path of the job config; must end with ``.yaml``.

        Raises:
            ValueError: if ``yaml_file`` does not end with ``.yaml``.
        """
        self.yaml_file = yaml_file

        if not yaml_file.endswith(".yaml"):
            # Build one readable message: passing the path as a second
            # positional argument would render the error as a tuple.
            raise ValueError(
                "unknown extension of job file: {}".format(yaml_file))

        config = recursive_config(yaml_file)
        job_key = "local"
        if config.task_type is not None:
            job_key = config.task_type.split("_")[0]
        self.job_key = job_key
예제 #4
0
    def _overwrite_task(self, job_config):
        """
        Build the per-task configs for every group in ``task_group``.

        For each group, every file named in its ``task_list`` is loaded from
        ``projects/<base_dir>``, merged with the group's remaining settings,
        passed through ``overwrite_dir``, and collected under its
        destination path inside ``self.project_dir``.

        Sets ``self.base_project_dir`` and ``self.base_run_dir``; reads
        ``self.project_dir`` and ``self.run_dir``, which must already be set.
        Note that ``task_list`` is popped from each group, so ``job_config``
        is modified in place.

        Args:
            job_config: job config exposing ``base_dir`` and ``task_group``.

        Returns:
            dict mapping each destination config path to its merged config.
        """
        configs_to_save = {}
        self.base_project_dir = os.path.join("projects", job_config.base_dir)
        self.base_run_dir = os.path.join("runs", job_config.base_dir)

        for group_name in job_config.task_group:
            overwrite_config = job_config.task_group[group_name]
            # task_list only selects which files to process;
            # we don't want this added to a final config.
            task_list = overwrite_config.pop("task_list", None)
            if task_list is None or len(task_list) == 0:
                # Name the offending group and skip it: there is nothing to
                # iterate, and looping over None would raise TypeError.
                print("[warning]", group_name,
                      "has no task_list specified.")
                continue

            for config_file in task_list:
                config_file_path = os.path.join(self.base_project_dir,
                                                config_file)
                config = recursive_config(config_file_path)
                # overwrite it (only when the group has settings left).
                if overwrite_config:
                    config = OmegaConf.merge(config, overwrite_config)
                # NOTE(review): presumably retargets directories in the
                # config from base_run_dir to run_dir; confirm against
                # overwrite_dir.
                overwrite_dir(config, self.run_dir, basedir=self.base_run_dir)
                save_file_path = os.path.join(self.project_dir, config_file)
                configs_to_save[save_file_path] = config
        return configs_to_save
예제 #5
0
 def __init__(self, yaml_file, dryrun=False):
     """Store the config path, its loaded config, and the dry-run flag.

     Args:
         yaml_file: path of the config, loaded with ``recursive_config``.
         dryrun: flag kept on the instance as ``self.dryrun``.
     """
     self.yaml_file = yaml_file
     loaded_config = recursive_config(yaml_file)
     self.config = loaded_config
     self.dryrun = dryrun