コード例 #1
0
ファイル: test_config.py プロジェクト: Steffen-Wolf/seml
 def test_generate_configs(self):
     """Check that ``CONFIG_WITH_ALL_TYPES`` expands into the expected 22 configurations.

     Generated and expected configurations are compared as sorted lists of
     hashes, so the order returned by ``generate_configs`` is irrelevant while
     the number of occurrences of each configuration still matters.
     """
     raw_config = self.load_config_dict(self.CONFIG_WITH_ALL_TYPES)
     generated = config.generate_configs(raw_config)
     assert len(generated) == 22

     # First group: only 'c' varies; the two variants each appear 5 times.
     group_one = [
         {'a': 9999, 'b': 7777, 'c': c_value, 'd': 1.0, 'e': 2.0}
         for c_value in (1234, 5678)
     ]
     # Second group: full grid over 'c' x 'f'; the four variants each appear 3 times.
     group_two = [
         {'a': 333, 'b': 444, 'c': c_value, 'd': 1.0, 'f': f_value}
         for c_value in (555, 666)
         for f_value in (9199, 1099)
     ]
     expected = group_one * 5 + group_two * 3

     expected_hashes = sorted(map(utils.make_hash, expected))
     generated_hashes = sorted(map(utils.make_hash, generated))
     assert expected_hashes == generated_hashes
コード例 #2
0
def add_configs(collection,
                seml_config,
                slurm_config,
                configs,
                source_files=None,
                git_info=None):
    """Put the input configurations into the database.

    Every configuration becomes one document with STAGED status. ``_id`` values
    continue after the current maximum ``_id`` in the collection, and all
    documents inserted by one call share a new batch ID and one ``add_time``.

    Parameters
    ----------
    collection: pymongo.collection.Collection
        The MongoDB collection containing the experiments.
    seml_config: dict
        Configuration for the SEML library. Not modified by this function.
    slurm_config: dict
        Settings for the Slurm job. See `start_experiments.start_slurm_job` for details.
    configs: list of dicts
        Contains the parameter configurations. Nothing is inserted if empty.
    source_files: (optional) list of tuples
        Contains the uploaded source files corresponding to the batch. Entries are of the form
        (object_id, relative_path). If given, it is stored as ``seml.source_files`` in every
        document (on a copy of ``seml_config``, not on the caller's dict).
    git_info: (Optional) dict containing information about the git repo status.

    Returns
    -------
    None

    """
    if not configs:
        return

    max_id = get_max_in_collection(collection, "_id")
    start_id = 1 if max_id is None else max_id + 1

    max_batch_id = get_max_in_collection(collection, "batch_id")
    batch_id = 1 if max_batch_id is None else max_batch_id + 1

    logging.info(
        f"Adding {len(configs)} configs to the database (batch-ID {batch_id})."
    )

    if source_files is not None:
        # Build a new dict rather than mutating the caller's seml_config.
        seml_config = {**seml_config, 'source_files': source_files}

    # One timestamp for the whole batch. Timezone-aware because
    # datetime.utcnow() is deprecated (Python 3.12); it denotes the same UTC instant.
    add_time = datetime.datetime.now(datetime.timezone.utc)
    db_dicts = [{
        '_id': start_id + ix,
        'batch_id': batch_id,
        'status': States.STAGED[0],
        'seml': seml_config,
        'slurm': slurm_config,
        'config': c,
        'config_hash': make_hash(c),
        'git': git_info,
        'add_time': add_time,
    } for ix, c in enumerate(configs)]

    collection.insert_many(db_dicts)
コード例 #3
0
def _apply_slurm_defaults(slurm_config, slurm_default):
    """Return ``slurm_config`` with missing entries filled in from ``slurm_default``.

    ``slurm_default`` is only read, never modified, so repeated calls in the same
    process keep seeing the complete set of defaults.

    Parameters
    ----------
    slurm_config: dict or None
        The Slurm section read from the config file. Completed in place (or newly
        created when None). A missing or empty ``sbatch_options`` entry becomes ``{}``.
    slurm_default: dict
        Default Slurm settings; its optional ``sbatch_options`` sub-dict supplies
        defaults for individual sbatch options.

    Returns
    -------
    dict
        The completed Slurm config.
    """
    if slurm_config is None:
        slurm_config = {}
    if slurm_config.get('sbatch_options') is None:
        slurm_config['sbatch_options'] = {}
    for key, value in slurm_default.get('sbatch_options', {}).items():
        slurm_config['sbatch_options'].setdefault(key, value)
    # 'sbatch_options' is always present in slurm_config at this point, so the
    # nested default dict is never copied over wholesale by this loop.
    for key, value in slurm_default.items():
        slurm_config.setdefault(key, value)
    return slurm_config


def _deduplicate_configs(configs, use_hash):
    """Return ``configs`` without duplicates, in first-occurrence order."""
    if use_hash:
        # Configs with the same 'config_hash' are treated as identical.
        return list({c['config_hash']: c for c in configs}.values())
    # Slow path: dicts are unhashable, so compare against the kept ones directly.
    unique_configs = []
    for c in configs:
        if c not in unique_configs:
            unique_configs.append(c)
    return unique_configs


def add_experiments(db_collection_name,
                    config_file,
                    force_duplicates,
                    no_hash=False,
                    no_sanity_check=False,
                    no_code_checkpoint=False):
    """
    Add configurations from a config file into the database.

    Parameters
    ----------
    db_collection_name: the MongoDB collection name.
    config_file: path to the YAML configuration.
    force_duplicates: if True, disable duplicate detection (no config hashes are computed here).
    no_hash: if True, disable hashing of the configurations for duplicate detection. This is much slower, so use only
        if you have a good reason to.
    no_sanity_check: if True, do not check the config for missing/unused arguments.
    no_code_checkpoint: if True, do not upload the experiment source code files to the MongoDB.

    Returns
    -------
    None
    """
    seml_config, slurm_config, experiment_config = read_config(config_file)

    # Use current Anaconda environment if not specified (None if there is none).
    if 'conda_environment' not in seml_config:
        seml_config['conda_environment'] = os.environ.get('CONDA_DEFAULT_ENV')

    # Set Slurm config with default parameters as fall-back option.
    # The global defaults are not modified, so this function can be called repeatedly.
    slurm_config = _apply_slurm_defaults(slurm_config, SETTINGS.SLURM_DEFAULT)
    slurm_config['sbatch_options'] = remove_prepended_dashes(
        slurm_config['sbatch_options'])

    configs = generate_configs(experiment_config)
    collection = get_collection(db_collection_name)

    max_batch_id = get_max_in_collection(collection, "batch_id")
    batch_id = 1 if max_batch_id is None else max_batch_id + 1

    if seml_config['use_uploaded_sources'] and not no_code_checkpoint:
        uploaded_files = upload_sources(seml_config, collection, batch_id)
    else:
        uploaded_files = None

    if not no_sanity_check:
        check_config(seml_config['executable'],
                     seml_config['conda_environment'], configs)

    path, commit, dirty = get_git_info(seml_config['executable'])
    git_info = None
    if path is not None:
        git_info = {'path': path, 'commit': commit, 'dirty': dirty}

    if not force_duplicates:
        use_hash = not no_hash
        if use_hash:
            # The hash is only needed for duplicate detection, so it is attached only
            # here; otherwise it would leak into the stored 'config' via add_configs.
            configs = [{**c, 'config_hash': make_hash(c)} for c in configs]

        len_before = len(configs)
        # First, check for duplicates within the experiment configurations from the file.
        configs = _deduplicate_configs(configs, use_hash)
        len_after_deduplication = len(configs)

        # Now, check for duplicate configurations in the database.
        configs = filter_experiments(collection, configs)
        len_after = len(configs)

        if len_after_deduplication != len_before:
            logging.info(
                f"{len_before - len_after_deduplication} of {len_before} experiment{s_if(len_before)} were "
                f"duplicates. Adding only the {len_after_deduplication} unique configurations."
            )
        if len_after != len_after_deduplication:
            logging.info(
                f"{len_after_deduplication - len_after} of {len_after_deduplication} "
                f"experiment{s_if(len_after_deduplication)} were already found in the database. They were not added again."
            )

    # Create an index on the config hash. If the index is already present, this simply does nothing.
    collection.create_index("config_hash")
    # Add the configurations to the database with STAGED status.
    if configs:
        add_configs(collection, seml_config, slurm_config, configs,
                    uploaded_files, git_info)