Exemple #1
0
def reset_experiments(db_collection_name, sacred_id, filter_states, batch_id, filter_dict):
    """
    Reset the state of one or several experiments in a database collection.

    Parameters
    ----------
    db_collection_name: str
        Database collection name.
    sacred_id: ID or None
        ID (`_id`) of the single experiment to reset. If None, the experiments are selected with the filter
        built from `filter_states`, `batch_id` and `filter_dict`.
    filter_states: str or list of str
        State(s) to filter for. A single string is treated like a one-element list.
    batch_id:
        Forwarded to `build_filter_dict`.
    filter_dict:
        Forwarded to `build_filter_dict`.

    Returns
    -------
    None

    Raises
    ------
    MongoDBError
        If `sacred_id` is given but no experiment with that ID exists.
    """
    collection = get_collection(db_collection_name)

    if sacred_id is not None:
        exp = collection.find_one({'_id': sacred_id})
        if exp is None:
            raise MongoDBError(f"No experiment found with ID {sacred_id}.")
        logging.info(f"Resetting the state of experiment with ID {sacred_id}.")
        reset_single_experiment(collection, exp)
        return

    # Normalize BEFORE building the set below: set("PENDING") would yield single characters,
    # so a state passed as a plain string would never trigger the detection of killed experiments.
    if isinstance(filter_states, str):
        filter_states = [filter_states]

    # Refresh the database state first if the filter touches states that may be stale.
    if len({*States.PENDING, *States.RUNNING, *States.KILLED} & set(filter_states)) > 0:
        detect_killed(db_collection_name, print_detected=False)

    filter_dict = build_filter_dict(filter_states, batch_id, filter_dict)

    nreset = collection.count_documents(filter_dict)
    exps = collection.find(filter_dict)

    if nreset >= 10:
        # Ask for confirmation before touching many experiments; anything but "y" aborts the program.
        if input(f"Resetting the state of {nreset} experiment{s_if(nreset)}. "
                 f"Are you sure? (y/n) ").lower() != "y":
            exit()
    else:
        logging.info(f"Resetting the state of {nreset} experiment{s_if(nreset)}.")
    for exp in exps:
        reset_single_experiment(collection, exp)
Exemple #2
0
def report_status(db_collection_name):
    """
    Log how many experiments of a collection are in each state.

    `detect_killed` is run first (without printing) so that the counts reflect the
    current state of the jobs.

    Parameters
    ----------
    db_collection_name: str
        Database collection name.

    Returns
    -------
    None
    """
    detect_killed(db_collection_name, print_detected=False)
    collection = get_collection(db_collection_name)

    def count_in(states):
        # Number of documents whose status is one of `states`.
        return collection.count_documents({'status': {'$in': states}})

    staged = count_in(States.STAGED)
    pending = count_in(States.PENDING)
    failed = count_in(States.FAILED)
    killed = count_in(States.KILLED)
    interrupted = count_in(States.INTERRUPTED)
    running = count_in(States.RUNNING)
    completed = count_in(States.COMPLETED)

    title = f"********** Report for database collection '{db_collection_name}' **********"
    report = [
        title,
        f"*     - {staged:3d} staged experiment{s_if(staged)}",
        f"*     - {pending:3d} pending experiment{s_if(pending)}",
        f"*     - {running:3d} running experiment{s_if(running)}",
        f"*     - {completed:3d} completed experiment{s_if(completed)}",
        f"*     - {interrupted:3d} interrupted experiment{s_if(interrupted)}",
        f"*     - {failed:3d} failed experiment{s_if(failed)}",
        f"*     - {killed:3d} killed experiment{s_if(killed)}",
        # Closing rule with the same width as the title line.
        "*" * len(title),
    ]
    for line in report:
        logging.info(line)
Exemple #3
0
def delete_experiments(db_collection_name, sacred_id, filter_states, batch_id, filter_dict):
    """
    Delete one or several experiments from a database collection.

    After the deletion, `delete_orphaned_sources` is called with the batch IDs of the
    deleted experiments (if there are any).

    Parameters
    ----------
    db_collection_name: str
        Database collection name.
    sacred_id: ID or None
        ID (`_id`) of the single experiment to delete. If None, the experiments are selected with the filter
        built from `filter_states`, `batch_id` and `filter_dict`.
    filter_states: iterable of str
        States to filter for; forwarded to `build_filter_dict`.
    batch_id:
        Forwarded to `build_filter_dict`.
    filter_dict:
        Forwarded to `build_filter_dict`.

    Returns
    -------
    None

    Raises
    ------
    MongoDBError
        If `sacred_id` is given but no experiment with that ID exists.
    """
    collection = get_collection(db_collection_name)

    if sacred_id is not None:
        exp = collection.find_one({'_id': sacred_id})
        if exp is None:
            raise MongoDBError(f"No experiment found with ID {sacred_id}.")
        logging.info(f"Deleting experiment with ID {sacred_id}.")
        batch_ids_in_del = {exp['batch_id']}
        collection.delete_one({'_id': sacred_id})
    else:
        # Refresh the database state first if the filter touches states that may be stale.
        possibly_stale = {*States.PENDING, *States.RUNNING, *States.KILLED}
        if possibly_stale & set(filter_states):
            detect_killed(db_collection_name, print_detected=False)

        filter_dict = build_filter_dict(filter_states, batch_id, filter_dict)
        ndelete = collection.count_documents(filter_dict)
        # Remember the affected batches before the documents are gone.
        batch_ids_in_del = {doc['batch_id'] for doc in collection.find(filter_dict, {'batch_id'})}

        if ndelete < 10:
            logging.info(f"Deleting {ndelete} configuration{s_if(ndelete)} from database collection.")
        else:
            # Ask for confirmation before deleting many entries; anything but "y" aborts the program.
            answer = input(f"Deleting {ndelete} configuration{s_if(ndelete)} from database collection. "
                           f"Are you sure? (y/n) ")
            if answer.lower() != "y":
                exit()
        collection.delete_many(filter_dict)

    if batch_ids_in_del:
        # clean up the uploaded sources if no experiments of a batch remain
        delete_orphaned_sources(collection, batch_ids_in_del)
Exemple #4
0
def detect_killed(db_collection_name, print_detected=True):
    """
    Find pending/running experiments whose Slurm job is gone and update their status.

    Experiments that are no longer found among the running Slurm array tasks are marked
    as INTERRUPTED if they already have a `stop_time`, and as KILLED otherwise. For killed
    experiments the last four lines of the output file are stored as `fail_trace`.

    Parameters
    ----------
    db_collection_name: str
        Database collection name.
    print_detected: bool
        Whether to log the number of experiments that were detected as killed.

    Returns
    -------
    None
    """
    collection = get_collection(db_collection_name)
    has_slurm_job = [{'slurm.array_id': {'$exists': True}},
                     {'slurm.id': {'$exists': True}}]
    candidates = collection.find({'status': {'$in': [*States.PENDING, *States.RUNNING]},
                                  '$or': has_slurm_job})
    running_jobs = get_slurm_arrays_tasks()
    nkilled = 0

    for exp in candidates:
        slurm_info = exp['slurm']
        still_running = False
        if 'array_id' in slurm_info and slurm_info['array_id'] in running_jobs:
            array_tasks = running_jobs[slurm_info['array_id']]
            # Entry 0 is searched as a collection of containers, entry 1 as a single container
            # of task IDs (presumably task ranges and individual task IDs -- TODO confirm).
            still_running = (any(slurm_info['task_id'] in r for r in array_tasks[0])
                             or slurm_info['task_id'] in array_tasks[1])
        if still_running:
            continue

        if 'stop_time' in exp:
            # The experiment was stopped on purpose, so it counts as interrupted, not killed.
            collection.update_one({'_id': exp['_id']},
                                  {'$set': {'status': States.INTERRUPTED[0]}})
            continue

        nkilled += 1
        collection.update_one({'_id': exp['_id']},
                              {'$set': {'status': States.KILLED[0]}})
        try:
            with open(exp['seml']['output_file'], 'r') as f:
                all_lines = f.readlines()
            collection.update_one({'_id': exp['_id']},
                                  {'$set': {'fail_trace': all_lines[-4:]}})
        except IOError:
            # If the experiment is cancelled before starting (e.g. when still queued), there is no output file.
            logging.verbose(
                f"File {exp['seml']['output_file']} could not be read."
            )

    if print_detected:
        logging.info(
            f"Detected {nkilled} externally killed experiment{s_if(nkilled)}.")
Exemple #5
0
def collect_exp_stats(run):
    """
    Gather resource usage statistics of the current run and store them in the MongoDB.

    The statistics contain the elapsed wall-clock time, user/system CPU time and maximum
    memory usage (for this process and for its children) and, if the respective framework
    has already been imported, the maximum GPU memory usage of PyTorch / TensorFlow 1.

    Parameters
    ----------
    run: Sacred run
        Current Sacred run. Nothing is collected if `run.config['overwrite']` is None
        or if the run is unobserved.

    Returns
    -------
    None
    """
    exp_id = run.config['overwrite']
    if exp_id is None or run.unobserved:
        return

    elapsed = datetime.datetime.utcnow() - run.start_time
    stats = {'real_time': elapsed.total_seconds()}

    usage_targets = (('self', resource.RUSAGE_SELF),
                     ('children', resource.RUSAGE_CHILDREN))
    for scope, who in usage_targets:
        stats[scope] = {
            'user_time': resource.getrusage(who).ru_utime,
            'system_time': resource.getrusage(who).ru_stime,
            # ru_maxrss is multiplied by 1024 to obtain bytes.
            'max_memory_bytes': 1024 * resource.getrusage(who).ru_maxrss,
        }

    # Only query frameworks the experiment has imported itself.
    if 'torch' in sys.modules:
        import torch
        stats['pytorch'] = {}
        if torch.cuda.is_available():
            stats['pytorch']['gpu_max_memory_bytes'] = torch.cuda.max_memory_allocated()

    if 'tensorflow' in sys.modules:
        import tensorflow as tf
        stats['tensorflow'] = {}
        tf_major_version = int(tf.__version__.split('.')[0])
        if tf_major_version < 2:
            if tf.test.is_gpu_available():
                stats['tensorflow']['gpu_max_memory_bytes'] = tf.contrib.memory_stats.MaxBytesInUse()
        elif len(tf.config.experimental.list_physical_devices('GPU')) >= 1:
            logging.info("SEML stats: There is currently no way to get actual GPU memory usage in TensorFlow 2.")

    collection = get_collection(run.config['db_collection'])
    collection.update_one({'_id': exp_id}, {'$set': {'stats': stats}})
Exemple #6
0
def detect_killed(db_collection_name, print_detected=True):
    """
    Find pending/running experiments whose Slurm job is gone and update their status.

    Experiments that are found neither among the running Slurm array tasks nor among the
    running Slurm jobs (old-style `slurm.id` entries) are marked as INTERRUPTED if they
    already have a `stop_time`, and as KILLED otherwise. For killed experiments the last
    four lines of the output file (if one is recorded and readable) are stored as `fail_trace`.

    Parameters
    ----------
    db_collection_name: str
        Database collection name.
    print_detected: bool
        Whether to log the number of experiments that were detected as killed.

    Returns
    -------
    None
    """
    collection = get_collection(db_collection_name)
    exps = collection.find({'status': {'$in': [*States.PENDING, *States.RUNNING]},
                            '$or': [{'slurm.array_id': {'$exists': True}}, {'slurm.id': {'$exists': True}}]})
    running_jobs = get_slurm_arrays_tasks()
    old_running_jobs = get_slurm_jobs()  # Backwards compatibility
    nkilled = 0
    for exp in exps:
        exp_running = ('array_id' in exp['slurm'] and exp['slurm']['array_id'] in running_jobs
                       and (any(exp['slurm']['task_id'] in r for r in running_jobs[exp['slurm']['array_id']][0])
                            or exp['slurm']['task_id'] in running_jobs[exp['slurm']['array_id']][1]))
        exp_running |= ('id' in exp['slurm'] and exp['slurm']['id'] in old_running_jobs)
        if exp_running:
            continue

        if 'stop_time' in exp:
            # The experiment was stopped on purpose, so it counts as interrupted, not killed.
            collection.update_one({'_id': exp['_id']}, {'$set': {'status': States.INTERRUPTED[0]}})
            continue

        nkilled += 1
        collection.update_one({'_id': exp['_id']}, {'$set': {'status': States.KILLED[0]}})

        seml_config = exp['seml']
        slurm_config = exp['slurm']
        if 'output_file' in seml_config:
            output_file = seml_config['output_file']
        elif 'output_file' in slurm_config:
            # Backward compatibility, we used to store the path in 'slurm'
            output_file = slurm_config['output_file']
        else:
            # No output file recorded, so there is no trace to store.
            continue

        # Only the file access is guarded: an I/O error raised by anything else must not be
        # reported as an unreadable output file.
        try:
            with open(output_file, 'r') as f:
                all_lines = f.readlines()
        except IOError:
            logging.warning(f"File {output_file} could not be read.")
            continue
        collection.update_one({'_id': exp['_id']}, {'$set': {'fail_trace': all_lines[-4:]}})
    if print_detected:
        logging.info(f"Detected {nkilled} externally killed experiment{s_if(nkilled)}.")
Exemple #7
0
def get_results(db_collection_name,
                fields=['config', 'result'],
                to_data_frame=False,
                mongodb_config=None,
                suffix=None,
                states=['COMPLETED'],
                filter_dict=None,
                parallel=False):
    """
    Load experiment entries from a database collection.

    Parameters
    ----------
    db_collection_name: str
        Database collection name.
    fields: list of str
        Projection passed to `collection.find`, i.e. the fields to load for each entry.
        The default list is only read, never modified.
    to_data_frame: bool
        If True, the parsed entries are flattened into a pandas DataFrame (nested keys joined with '.').
    mongodb_config:
        Forwarded to `get_collection`.
    suffix:
        Forwarded to `get_collection`.
    states: list of str
        If not empty, only entries whose status is in this list are loaded. This replaces a
        'status' entry of `filter_dict` (a warning is logged in that case).
    filter_dict: dict or None
        Additional filter for the database query. The passed dictionary is not modified.
    parallel: bool
        If True, the entries are parsed with a `multiprocessing.Pool`.

    Returns
    -------
    list or pandas.DataFrame
        The entries parsed with `parse_jsonpickle`, as a DataFrame if `to_data_frame` is True.
    """
    # Work on a copy so that adding the status filter never leaks into the caller's dictionary.
    filter_dict = {} if filter_dict is None else dict(filter_dict)

    collection = get_collection(db_collection_name,
                                mongodb_config=mongodb_config,
                                suffix=suffix)

    if len(states) > 0:
        if 'status' in filter_dict:
            logging.warning(
                "'states' argument is not empty and will overwrite 'filter_dict['status']'."
            )
        filter_dict['status'] = {'$in': states}

    cursor = collection.find(filter_dict, fields)
    results = list(tqdm(cursor, total=collection.count_documents(filter_dict)))

    if parallel:
        from multiprocessing import Pool
        with Pool() as p:
            parsed = list(
                tqdm(p.imap(parse_jsonpickle, results), total=len(results)))
    else:
        parsed = [parse_jsonpickle(entry) for entry in tqdm(results)]

    if to_data_frame:
        # pandas is only needed for this conversion.
        import pandas as pd

        # `pandas.json_normalize` is the public location since pandas 1.0; fall back to the
        # old `pandas.io.json` location for older installations.
        normalize = getattr(pd, 'json_normalize', None)
        if normalize is None:
            normalize = pd.io.json.json_normalize
        parsed = normalize(parsed, sep='.')
    return parsed
Exemple #8
0
def report_status(db_collection_name):
    """
    Log how many experiments of a collection are in each state.

    `detect_killed` is run first (without printing) so that the counts reflect the
    current state of the jobs.

    Parameters
    ----------
    db_collection_name: str
        Database collection name.

    Returns
    -------
    None
    """
    detect_killed(db_collection_name, print_detected=False)
    collection = get_collection(db_collection_name)

    def count_with(status):
        # Number of documents with exactly this status.
        return collection.count_documents({'status': status})

    queued = count_with('QUEUED')
    pending = count_with('PENDING')
    failed = count_with('FAILED')
    killed = count_with('KILLED')
    interrupted = count_with('INTERRUPTED')
    running = count_with('RUNNING')
    completed = count_with('COMPLETED')

    title = f"********** Report for database collection '{db_collection_name}' **********"
    report = [
        title,
        f"*     - {queued:3d} queued experiment{s_if(queued)}",
        f"*     - {pending:3d} pending experiment{s_if(pending)}",
        f"*     - {running:3d} running experiment{s_if(running)}",
        f"*     - {completed:3d} completed experiment{s_if(completed)}",
        f"*     - {interrupted:3d} interrupted experiment{s_if(interrupted)}",
        f"*     - {failed:3d} failed experiment{s_if(failed)}",
        f"*     - {killed:3d} killed experiment{s_if(killed)}",
        # Closing rule with the same width as the title line.
        "*" * len(title),
    ]
    for line in report:
        logging.info(line)
Exemple #9
0
def cancel_experiments(db_collection_name, sacred_id, filter_states, batch_id, filter_dict):
    """
    Cancel experiments.

    Parameters
    ----------
    db_collection_name: str
        Database collection name.
    sacred_id: int or None
        ID of the experiment to cancel. If None, will use the other arguments to cancel possible multiple experiments.
    filter_states: list of strings or None
        List of statuses to filter for. Will cancel all jobs from the database collection
        with one of the given statuses.
    batch_id: int or None
        The ID of the batch of experiments to cancel. All experiments that are staged together (i.e. within the same
        command line call) have the same batch ID.
    filter_dict: dict or None
        Arbitrary filter dictionary to use for cancelling experiments. Any experiments whose database entries match all
        keys/values of the dictionary will be cancelled.

    Returns
    -------
    None

    Notes
    -----
    If `scancel` fails, a warning is logged and the database entries are left unchanged.
    """
    collection = get_collection(db_collection_name)
    if sacred_id is not None:
        logging.info(f"Cancelling experiment with ID {sacred_id}.")
        cancel_experiment_by_id(collection, sacred_id)
        return

    # no ID is provided: we check whether there are slurm jobs for which after this action no
    # RUNNING experiment remains. These slurm jobs can be killed altogether.
    # However, it is NOT possible right now to cancel a single experiment in a Slurm job with multiple
    # running experiments.
    try:
        # `filter_states` may be None (see above); set(None) would raise a TypeError.
        if len({*States.PENDING, *States.RUNNING, *States.KILLED} & set(filter_states or [])) > 0:
            detect_killed(db_collection_name, print_detected=False)

        filter_dict = build_filter_dict(filter_states, batch_id, filter_dict)

        ncancel = collection.count_documents(filter_dict)
        if ncancel >= 10:
            if input(f"Cancelling {ncancel} experiment{s_if(ncancel)}. "
                     f"Are you sure? (y/n) ").lower() != "y":
                exit()
        else:
            logging.info(f"Cancelling {ncancel} experiment{s_if(ncancel)}.")

        filter_dict_new = copy.deepcopy(filter_dict)
        filter_dict_new.update({'slurm.array_id': {'$exists': True}})
        exps = list(collection.find(filter_dict_new,
                                    {'_id': 1, 'status': 1, 'slurm.array_id': 1, 'slurm.task_id': 1}))
        # set of slurm IDs in the database
        slurm_ids = {(e['slurm']['array_id'], e['slurm']['task_id']) for e in exps}
        # set of experiment IDs to be cancelled.
        exp_ids = {e['_id'] for e in exps}
        to_cancel = set()

        # iterate over slurm IDs to check which slurm jobs can be cancelled altogether
        for a_id, t_id in slurm_ids:
            # Find ALL experiments RUNNING under the slurm job. This has to be a fresh query: `exps` only
            # contains experiments that are cancelled anyway, so checking against it could never find a
            # running experiment that must be kept alive.
            jobs_running = collection.find({'slurm.array_id': a_id,
                                            'slurm.task_id': t_id,
                                            'status': {'$in': States.RUNNING}},
                                           {'_id': 1})
            running_exp_ids = {e['_id'] for e in jobs_running}
            if running_exp_ids <= exp_ids:
                # there are no running jobs in this slurm job that should not be canceled.
                to_cancel.add(f"{a_id}_{t_id}")

        # cancel all Slurm jobs for which no running experiment remains.
        if to_cancel:
            chunk_size = 100
            for chunk in chunker(list(to_cancel), chunk_size):
                # Argument list instead of a shell string: no shell parsing of values read from the database.
                subprocess.run(["scancel", *chunk], check=True)

        # update database status and write the stop_time
        collection.update_many(filter_dict, {'$set': {"status": States.INTERRUPTED[0],
                                                      "stop_time": datetime.datetime.utcnow()}})
    except subprocess.CalledProcessError:
        logging.warning("One or multiple Slurm jobs were no longer running when I tried to cancel them.")
Exemple #10
0
def start_experiments(db_collection_name, local, sacred_id, batch_id, filter_dict,
                      num_exps, post_mortem, debug, debug_server,
                      output_to_console, no_file_output, steal_slurm,
                      no_worker, set_to_pending=True,
                      worker_gpus=None, worker_cpus=None, worker_environment_vars=None):
    """
    Prepare experiments and start them, either via Slurm or with a local worker.

    With `debug` or `debug_server` a single experiment is run unobserved, with post-mortem
    debugging and console output enabled, and the root log level is set to VERBOSE.

    Parameters
    ----------
    db_collection_name: str
        Database collection name.
    local: bool
        If True, `check_compute_node` is called and a local worker is started (unless `no_worker` is set).
        Otherwise the experiments are added to the Slurm queue.
    sacred_id: ID or None
        If given, only the experiment with this `_id` is selected. Otherwise the filter is built from
        `batch_id` and `filter_dict`, restricted to staged experiments unless it already contains a 'status'.
    batch_id, filter_dict:
        Forwarded to `build_filter_dict`.
    num_exps:
        Forwarded to `prepare_experiments` and `start_local_worker`; forced to 1 in debug mode.
    post_mortem, output_to_console: bool
        Forwarded to the launcher; not allowed in regular (non-debug) Slurm mode.
    debug, debug_server: bool
        Enable debug mode (see above); `debug_server` is additionally forwarded to the launcher.
    no_file_output: bool
        Inverted and forwarded as `output_to_file`.
    steal_slurm, no_worker, worker_gpus, worker_cpus, worker_environment_vars:
        Options of the local worker; only allowed in local mode.
    set_to_pending: bool
        Forwarded to `prepare_experiments`; forced to False for unobserved runs.

    Returns
    -------
    None

    Raises
    ------
    ArgumentError
        If an argument is used in a mode that does not support it, or if a debug server is requested
        for an experiment whose source files are stored in the MongoDB.
    """
    output_to_file = not no_file_output
    launch_worker = not no_worker

    debug_mode = bool(debug or debug_server)
    unobserved = debug_mode
    srun = debug_mode
    if debug_mode:
        num_exps = 1
        post_mortem = True
        output_to_console = True
        logging.root.setLevel(logging.VERBOSE)

    if local:
        check_compute_node()
    else:
        local_only_args = (("--no-worker", no_worker),
                           ("--steal-slurm", steal_slurm),
                           ("--worker-gpus", worker_gpus),
                           ("--worker-cpus", worker_cpus),
                           ("--worker-environment-vars", worker_environment_vars))
        for key, value in local_only_args:
            if value:
                raise ArgumentError(f"The argument '{key}' only works in local mode, not in Slurm mode.")

        if not srun:
            debug_only_args = (("--post-mortem", post_mortem),
                               ("--output-to-console", output_to_console))
            for key, value in debug_only_args:
                if value:
                    raise ArgumentError(f"The argument '{key}' does not work in regular Slurm mode. "
                                        "Remove the argument or use '--debug'.")

    if unobserved:
        set_to_pending = False

    if sacred_id is not None:
        filter_dict = {'_id': sacred_id}
    else:
        filter_dict = build_filter_dict([], batch_id, filter_dict)
        if 'status' not in filter_dict:
            filter_dict['status'] = {"$in": States.STAGED}

    collection = get_collection(db_collection_name)

    staged_experiments = prepare_experiments(
            collection=collection, filter_dict=filter_dict, num_exps=num_exps,
            slurm=not local, set_to_pending=set_to_pending, print_pending=local)

    if debug_server and 'source_files' in staged_experiments[0]['seml']:
        raise ArgumentError("Cannot use a debug server with source code that is loaded from the MongoDB. "
                            "Use the `--no-code-checkpoint` option when adding the experiment.")

    # Keyword arguments shared by both launchers.
    shared_kwargs = dict(collection=collection, unobserved=unobserved, post_mortem=post_mortem,
                         output_to_file=output_to_file, output_to_console=output_to_console,
                         debug_server=debug_server)
    if not local:
        add_to_slurm_queue(exps_list=staged_experiments, srun=srun, **shared_kwargs)
    elif launch_worker:
        start_local_worker(num_exps=num_exps, filter_dict=filter_dict, steal_slurm=steal_slurm,
                           gpus=worker_gpus, cpus=worker_cpus, environment_variables=worker_environment_vars,
                           **shared_kwargs)
Exemple #11
0
def print_command(db_collection_name, sacred_id, batch_id, filter_dict, num_exps,
                  worker_gpus=None, worker_cpus=None, worker_environment_vars=None):
    """
    Log the commands that would be used to run the selected experiments.

    For the first selected experiment the executable, the Anaconda environment (if any), the
    arguments for the VS Code and PyCharm debuggers and the commands for post-mortem and remote
    debugging are logged. Afterwards the raw command of every selected experiment is logged.
    Nothing is logged if no experiment matches.

    Parameters
    ----------
    db_collection_name: str
        Database collection name.
    sacred_id: ID or None
        If given, only the experiment with this `_id` is selected. Otherwise the filter is built from
        `batch_id` and `filter_dict`, restricted to staged experiments unless it already contains a 'status'.
    batch_id, filter_dict:
        Forwarded to `build_filter_dict`.
    num_exps:
        Passed as `limit` to the database query.
    worker_gpus, worker_cpus, worker_environment_vars:
        Forwarded to `get_environment_variables`; the resulting variables are prepended to the commands.

    Returns
    -------
    None
    """
    collection = get_collection(db_collection_name)

    if sacred_id is None:
        filter_dict = build_filter_dict([], batch_id, filter_dict)
        if 'status' not in filter_dict:
            filter_dict['status'] = {"$in": States.STAGED}
    else:
        filter_dict = {'_id': sacred_id}

    env_dict = get_environment_variables(worker_gpus, worker_cpus, worker_environment_vars)
    env_str = " ".join(f"{k}={v}" for k, v in env_dict.items())
    if env_str:
        env_str += " "

    exps_list = list(collection.find(filter_dict, limit=num_exps))
    if not exps_list:
        return

    # The commands of the first experiment are generated at VERBOSE level. The previous level is
    # restored in `finally` so that it cannot leak to the caller, even if an error occurs.
    orig_level = logging.root.level
    logging.root.setLevel(logging.VERBOSE)
    try:
        exp = exps_list[0]
        _, exe, config = get_command_from_exp(exp, collection.name,
                                              verbose=logging.root.level <= logging.VERBOSE,
                                              unobserved=True, post_mortem=False)
        env = exp['seml']['conda_environment'] if 'conda_environment' in exp['seml'] else None

        logging.info("********** First experiment **********")
        logging.info(f"Executable: {exe}")
        if env is not None:
            logging.info(f"Anaconda environment: {env}")
        config.insert(0, 'with')
        config.append('--debug')

        # Remove double quotes, change single quotes to escaped double quotes
        config_vscode = [c.replace('"', '') for c in config]
        config_vscode = [c.replace("'", '\\"') for c in config_vscode]

        logging.info("\nArguments for VS Code debugger:")
        logging.info('["' + '", "'.join(config_vscode) + '"]')
        logging.info("Arguments for PyCharm debugger:")
        logging.info(" ".join(config))

        logging.info("\nCommand for post-mortem debugging:")
        interpreter, exe, config = get_command_from_exp(exp, collection.name,
                                                        verbose=logging.root.level <= logging.VERBOSE,
                                                        unobserved=True, post_mortem=True)
        logging.info(f"{env_str}{interpreter} {exe} with {' '.join(config)}")

        logging.info("\nCommand for remote debugging:")
        interpreter, exe, config = get_command_from_exp(exp, collection.name,
                                                        verbose=logging.root.level <= logging.VERBOSE,
                                                        unobserved=True, debug_server=True, print_info=False)
        logging.info(f"{env_str}{interpreter} {exe} with {' '.join(config)}")

        logging.info("\n********** All raw commands **********")
    finally:
        logging.root.setLevel(orig_level)

    # The raw commands are generated with the caller's log level.
    for exp in exps_list:
        interpreter, exe, config = get_command_from_exp(
                exp, collection.name, verbose=logging.root.level <= logging.VERBOSE)
        logging.info(f"{env_str}{interpreter} {exe} with {' '.join(config)}")
Exemple #12
0
def add_experiments(db_collection_name,
                    config_file,
                    force_duplicates,
                    no_hash=False,
                    no_sanity_check=False,
                    no_code_checkpoint=False):
    """
    Add configurations from a config file into the database.

    Parameters
    ----------
    db_collection_name: the MongoDB collection name.
    config_file: path to the YAML configuration.
    force_duplicates: if True, disable duplicate detection.
    no_hash: if True, disable hashing of the configurations for duplicate detection. This is much slower, so use only
        if you have a good reason to.
    no_sanity_check: if True, do not check the config for missing/unused arguments.
    no_code_checkpoint: if True, do not upload the experiment source code files to the MongoDB.

    Returns
    -------
    None
    """

    seml_config, slurm_config, experiment_config = read_config(config_file)

    # Use current Anaconda environment if not specified
    if 'conda_environment' not in seml_config:
        seml_config['conda_environment'] = os.environ.get('CONDA_DEFAULT_ENV')

    # Set Slurm config with default parameters as fall-back option
    if slurm_config is None:
        slurm_config = {'sbatch_options': {}}
    sbatch_options = slurm_config.setdefault('sbatch_options', {})
    for k, v in SETTINGS.SLURM_DEFAULT['sbatch_options'].items():
        if k not in sbatch_options:
            sbatch_options[k] = v
    # The global defaults are only read, never modified: deleting 'sbatch_options' from them would make
    # every later call in the same process fail with a KeyError.
    for k, v in SETTINGS.SLURM_DEFAULT.items():
        if k != 'sbatch_options' and k not in slurm_config:
            slurm_config[k] = v

    slurm_config['sbatch_options'] = remove_prepended_dashes(
        slurm_config['sbatch_options'])
    configs = generate_configs(experiment_config)
    collection = get_collection(db_collection_name)

    # All configurations added by this call share a new batch ID.
    batch_id = get_max_in_collection(collection, "batch_id")
    batch_id = 1 if batch_id is None else batch_id + 1

    if seml_config['use_uploaded_sources'] and not no_code_checkpoint:
        uploaded_files = upload_sources(seml_config, collection, batch_id)
    else:
        uploaded_files = None

    if not no_sanity_check:
        check_config(seml_config['executable'],
                     seml_config['conda_environment'], configs)

    path, commit, dirty = get_git_info(seml_config['executable'])
    git_info = None
    if path is not None:
        git_info = {'path': path, 'commit': commit, 'dirty': dirty}

    use_hash = not no_hash
    if use_hash:
        configs = [{**c, **{'config_hash': make_hash(c)}} for c in configs]

    if not force_duplicates:
        len_before = len(configs)

        # First, check for duplicates within the experiment configurations from the file.
        if not use_hash:
            # slow duplicate detection without hashes
            unique_configs = []
            for c in configs:
                if c not in unique_configs:
                    unique_configs.append(c)
            configs = unique_configs
        else:
            # fast duplicate detection using hashing.
            configs_dict = {c['config_hash']: c for c in configs}
            configs = list(configs_dict.values())

        len_after_deduplication = len(configs)
        # Now, check for duplicate configurations in the database.
        configs = filter_experiments(collection, configs)
        len_after = len(configs)
        if len_after_deduplication != len_before:
            logging.info(
                f"{len_before - len_after_deduplication} of {len_before} experiment{s_if(len_before)} were "
                f"duplicates. Adding only the {len_after_deduplication} unique configurations."
            )
        if len_after != len_after_deduplication:
            # The plural suffix has to match the number that precedes "experiment".
            logging.info(
                f"{len_after_deduplication - len_after} of {len_after_deduplication} "
                f"experiment{s_if(len_after_deduplication)} were already found in the database. "
                f"They were not added again."
            )

    # Create an index on the config hash. If the index is already present, this simply does nothing.
    collection.create_index("config_hash")
    # Add the configurations to the database with STAGED status.
    if len(configs) > 0:
        add_configs(collection, seml_config, slurm_config, configs,
                    uploaded_files, git_info)
Exemple #13
0
                        help="Run the experiments without Sacred observers.")
    parser.add_argument("--post-mortem",
                        default=False,
                        type=lambda x: (str(x).lower() == 'true'),
                        help="Activate post-mortem debugging with pdb.")
    parser.add_argument(
        "--stored-sources-dir",
        default=None,
        type=str,
        help="Load source files into this directory before starting.")
    args = parser.parse_args()

    exp_id = args.experiment_id
    db_collection_name = args.db_collection_name

    collection = get_collection(db_collection_name)

    exp = collection.find_one({'_id': exp_id})
    use_stored_sources = args.stored_sources_dir is not None
    if use_stored_sources and not os.listdir(args.stored_sources_dir):
        assert "source_files" in exp['seml'],\
               "--stored-sources-dir was supplied but queued experiment does not contain stored source files."
        load_sources_from_db(exp,
                             collection,
                             to_directory=args.stored_sources_dir)

    exe, config = get_command_from_exp(exp,
                                       db_collection_name,
                                       verbose=args.verbose,
                                       unobserved=args.unobserved,
                                       post_mortem=args.post_mortem)
Exemple #14
0
def start_experiments(db_collection_name,
                      local,
                      sacred_id,
                      batch_id,
                      filter_dict,
                      num_exps,
                      post_mortem,
                      debug,
                      debug_server,
                      print_command,
                      output_to_console,
                      no_file_output,
                      steal_slurm,
                      no_worker,
                      set_to_pending=True,
                      worker_gpus=None,
                      worker_cpus=None,
                      worker_environment_vars=None):
    """
    Prepare staged experiments and start them, either via Slurm or with a local worker.

    With `debug` or `debug_server` a single experiment is run unobserved, with post-mortem
    debugging and console output enabled, and the root log level is set to VERBOSE.

    Parameters
    ----------
    db_collection_name: str
        Database collection name.
    local: bool
        If True, a local worker is started (unless `no_worker` is set). Otherwise the experiments are
        added to the Slurm queue.
    sacred_id: ID or None
        If given, the experiment with this `_id` is selected regardless of its state (in addition to
        `filter_dict`). Otherwise the filter is built from `batch_id` and `filter_dict`.
    batch_id:
        Forwarded to `build_filter_dict`.
    filter_dict: dict or None
        Filter for the database query. The passed dictionary is not modified.
    num_exps:
        Forwarded to the called helpers; forced to 1 in debug mode.
    post_mortem, output_to_console: bool
        Forwarded to the launcher; not allowed in regular (non-debug) Slurm mode.
    debug, debug_server: bool
        Enable debug mode (see above); `debug_server` is additionally forwarded to the called helpers.
    print_command: bool
        If True, only `print_commands` is called and nothing is started.
    no_file_output: bool
        Inverted and forwarded as `output_to_file`.
    steal_slurm, no_worker, worker_gpus, worker_cpus, worker_environment_vars:
        Options of the local worker; only allowed in local mode.
    set_to_pending: bool
        Forwarded to `prepare_staged_experiments`; forced to False for unobserved runs.

    Returns
    -------
    None

    Raises
    ------
    ArgumentError
        If an argument is used in a mode that does not support it.
    """

    use_slurm = not local
    output_to_file = not no_file_output
    launch_worker = not no_worker

    if debug or debug_server:
        num_exps = 1
        unobserved = True
        post_mortem = True
        output_to_console = True
        srun = True
        logging.root.setLevel(logging.VERBOSE)
    else:
        unobserved = False
        srun = False

    if not local:
        local_kwargs = {
            "--no-worker": no_worker,
            "--steal-slurm": steal_slurm,
            "--worker-gpus": worker_gpus,
            "--worker-cpus": worker_cpus,
            "--worker-environment-vars": worker_environment_vars
        }
        for key, val in local_kwargs.items():
            if val:
                raise ArgumentError(
                    f"The argument '{key}' only works in local mode, not in Slurm mode."
                )
    if not local and not srun:
        non_sbatch_kwargs = {
            "--post-mortem": post_mortem,
            "--output-to-console": output_to_console
        }
        for key, val in non_sbatch_kwargs.items():
            if val:
                raise ArgumentError(
                    f"The argument '{key}' does not work in regular Slurm mode. "
                    "Remove the argument or use '--debug'.")

    # Work on a copy: the filter is extended below and the caller's dictionary must stay untouched.
    filter_dict = {} if filter_dict is None else dict(filter_dict)

    if unobserved:
        set_to_pending = False

    if worker_environment_vars is None:
        worker_environment_vars = {}

    if sacred_id is None:
        filter_dict = build_filter_dict([], batch_id, filter_dict)
    else:
        # if we have a specific sacred ID, we ignore the state of the experiment and run it in any case.
        all_states = (States.PENDING + States.STAGED + States.RUNNING +
                      States.FAILED + States.INTERRUPTED + States.KILLED +
                      States.COMPLETED)
        filter_dict.update({'_id': sacred_id, "status": {"$in": all_states}})

    collection = get_collection(db_collection_name)

    if print_command:
        print_commands(collection,
                       unobserved=unobserved,
                       post_mortem=post_mortem,
                       debug_server=debug_server,
                       num_exps=num_exps,
                       filter_dict=filter_dict)
        return

    staged_experiments = prepare_staged_experiments(
        collection=collection,
        filter_dict=filter_dict,
        num_exps=num_exps,
        slurm=use_slurm,
        set_to_pending=set_to_pending,
        print_pending=not use_slurm)

    if use_slurm:
        add_to_slurm_queue(collection=collection,
                           exps_list=staged_experiments,
                           unobserved=unobserved,
                           post_mortem=post_mortem,
                           output_to_file=output_to_file,
                           output_to_console=output_to_console,
                           srun=srun,
                           debug_server=debug_server)
    elif launch_worker:
        start_local_worker(collection=collection,
                           num_exps=num_exps,
                           filter_dict=filter_dict,
                           unobserved=unobserved,
                           post_mortem=post_mortem,
                           steal_slurm=steal_slurm,
                           output_to_console=output_to_console,
                           output_to_file=output_to_file,
                           gpus=worker_gpus,
                           cpus=worker_cpus,
                           environment_variables=worker_environment_vars,
                           debug_server=debug_server)
Exemple #15
0
def start_jobs(db_collection_name,
               slurm=True,
               unobserved=False,
               post_mortem=False,
               num_exps=-1,
               filter_dict=None,
               dry_run=False,
               output_to_file=True):
    """Pull queued experiments from the database and run them.

    Parameters
    ----------
    db_collection_name: str
        Name of the collection in the MongoDB.
    slurm: bool
        Use the Slurm cluster.
    unobserved: bool
        Disable all Sacred observers (nothing written to MongoDB).
    post_mortem: bool
        Activate post-mortem debugging.
    num_exps: int, default: -1
        If >0, will only submit the specified number of experiments to the cluster.
        This is useful when you only want to test your setup.
    filter_dict: dict
        Dictionary for filtering the entries in the collection.
    dry_run: bool
        Just return the executables and configurations instead of running them.
    output_to_file: bool
        Pipe all output (stdout and stderr) to an output file.
        Can only be False if slurm is False.

    Returns
    -------
    list of tuple or None
        If ``dry_run`` is True, a list with one
        ``(executable, conda_environment_or_None, config)`` tuple per selected
        experiment. Otherwise None (also when no experiment matches the query).
    """
    if filter_dict is None:
        filter_dict = {}

    collection = get_collection(db_collection_name)

    # An unobserved local run that targets an explicit '_id' is not restricted
    # to queued experiments; every other case only considers 'QUEUED' ones.
    if unobserved and not slurm and '_id' in filter_dict:
        query_dict = {}
    else:
        query_dict = {'status': {"$in": ['QUEUED']}}
    query_dict.update(filter_dict)

    # Fetch once and test for emptiness instead of issuing a separate count
    # query whose result could differ from the subsequent find.
    exps_full = list(collection.find(query_dict))
    if not exps_full:
        logging.error("No queued experiments.")
        return

    exps_list = exps_full[:num_exps] if num_exps > 0 else exps_full
    # Report the number of experiments actually selected, which may be smaller
    # than the requested num_exps.
    nexps = len(exps_list)

    if dry_run:
        configs = []
        for exp in exps_list:
            exe, config = get_command_from_exp(
                exp,
                db_collection_name,
                verbose=logging.root.level <= logging.VERBOSE,
                unobserved=unobserved,
                post_mortem=post_mortem)
            conda_env = exp['seml'].get('conda_environment') \
                if 'conda_environment' in exp['seml'] else None
            configs.append((exe, conda_env, config))
        return configs
    elif slurm:
        if not output_to_file:
            logging.error("Output cannot be written to stdout in Slurm mode. "
                          "Remove the '--output-to-console' argument.")
            sys.exit(1)
        exp_chunks = chunk_list(exps_list)
        exp_arrays = batch_chunks(exp_chunks)
        njobs = len(exp_chunks)
        narrays = len(exp_arrays)

        logging.info(
            f"Starting {nexps} experiment{s_if(nexps)} in "
            f"{njobs} Slurm job{s_if(njobs)} in {narrays} Slurm job array{s_if(narrays)}."
        )

        for exp_array in exp_arrays:
            # The first experiment of the first chunk supplies the job name,
            # output directory and Slurm settings for the whole array.
            first_exp = exp_array[0][0]
            job_name = get_exp_name(first_exp, collection.name)
            output_dir_path = get_output_dir_path(first_exp)
            slurm_config = first_exp['slurm']
            # 'experiments_per_job' must not be forwarded as a keyword
            # argument to start_slurm_job.
            del slurm_config['experiments_per_job']
            start_slurm_job(collection,
                            exp_array,
                            unobserved,
                            post_mortem,
                            name=job_name,
                            output_dir_path=output_dir_path,
                            **slurm_config)
    else:
        login_node_name = 'fs'
        if login_node_name in os.uname()[1]:
            logging.error(
                "Refusing to run a compute experiment on a login node. "
                "Please use Slurm or a compute node.")
            sys.exit(1)
        # Check if output dir exists (called for every selected experiment
        # before anything is started).
        for exp in exps_list:
            get_output_dir_path(exp)
        logging.info(
            f'Starting local worker thread that will run up to {nexps} experiment{s_if(nexps)}, '
            f'until no queued experiments remain.')
        if not unobserved:
            collection.update_many(
                {'_id': {
                    '$in': [e['_id'] for e in exps_list]
                }}, {"$set": {
                    "status": "PENDING"
                }})
        num_exceptions = 0
        # Count from 1 so the postfix shows failures out of experiments run so
        # far (a 0-based index would display e.g. "1/0" after the first one).
        tq = tqdm(enumerate(exps_list, start=1))
        for i_exp, exp in tq:
            if output_to_file:
                output_dir_path = get_output_dir_path(exp)
            else:
                output_dir_path = None
            success = start_local_job(collection, exp, unobserved, post_mortem,
                                      output_dir_path)
            if success is False:
                num_exceptions += 1
            tq.set_postfix(failed=f"{num_exceptions}/{i_exp} experiments")