예제 #1
0
def reset_experiments(db_collection_name, sacred_id, filter_states, batch_id, filter_dict):
    """
    Reset the state of experiments in a database collection.

    Parameters
    ----------
    db_collection_name: str
        Database collection name.
    sacred_id: int or None
        ID of the experiment to reset. If None, the other arguments are used to select
        (possibly multiple) experiments.
    filter_states: str, list of strings or None
        Statuses to filter for. A single string is treated as a one-element list.
    batch_id: int or None
        Batch ID to filter for.
    filter_dict: dict or None
        Additional filter dictionary, combined with the other criteria by ``build_filter_dict``.

    Returns
    -------
    None

    Raises
    ------
    MongoDBError
        If ``sacred_id`` is given but no experiment with that ID exists.
    """
    collection = get_collection(db_collection_name)

    if sacred_id is None:
        # Normalise BEFORE inspecting the states: set() over a bare string would split it into
        # characters and the state check below would never match.
        if isinstance(filter_states, str):
            filter_states = [filter_states]

        # NOTE(review): detect_killed presumably refreshes the status of experiments whose jobs
        # died, so that filtering on these states sees up-to-date values.
        if filter_states and {*States.PENDING, *States.RUNNING, *States.KILLED} & set(filter_states):
            detect_killed(db_collection_name, print_detected=False)

        filter_dict = build_filter_dict(filter_states, batch_id, filter_dict)

        nreset = collection.count_documents(filter_dict)
        exps = collection.find(filter_dict)

        # Ask for confirmation before touching many experiments at once.
        if nreset >= 10:
            if input(f"Resetting the state of {nreset} experiment{s_if(nreset)}. "
                     f"Are you sure? (y/n) ").lower() != "y":
                exit()
        else:
            logging.info(f"Resetting the state of {nreset} experiment{s_if(nreset)}.")
        for exp in exps:
            reset_single_experiment(collection, exp)
    else:
        exp = collection.find_one({'_id': sacred_id})
        if exp is None:
            raise MongoDBError(f"No experiment found with ID {sacred_id}.")
        logging.info(f"Resetting the state of experiment with ID {sacred_id}.")
        reset_single_experiment(collection, exp)
예제 #2
0
def delete_experiments(db_collection_name, sacred_id, filter_states, batch_id, filter_dict):
    """
    Delete experiments from a database collection.

    After deletion, ``delete_orphaned_sources`` is called for every batch that lost
    experiments.

    Parameters
    ----------
    db_collection_name: str
        Database collection name.
    sacred_id: int or None
        ID of the experiment to delete. If None, the other arguments are used to select
        (possibly multiple) experiments.
    filter_states: str, list of strings or None
        Statuses to filter for. A single string is treated as a one-element list.
    batch_id: int or None
        Batch ID to filter for.
    filter_dict: dict or None
        Additional filter dictionary, combined with the other criteria by ``build_filter_dict``.

    Returns
    -------
    None

    Raises
    ------
    MongoDBError
        If ``sacred_id`` is given but no experiment with that ID exists.
    """
    collection = get_collection(db_collection_name)
    if sacred_id is None:
        # Accept a single state as well as a list; set() over a bare string would split it into
        # characters and the state check below would never match.
        if isinstance(filter_states, str):
            filter_states = [filter_states]

        # NOTE(review): detect_killed presumably refreshes the status of experiments whose jobs
        # died, so that filtering on these states sees up-to-date values.
        if filter_states and {*States.PENDING, *States.RUNNING, *States.KILLED} & set(filter_states):
            detect_killed(db_collection_name, print_detected=False)

        filter_dict = build_filter_dict(filter_states, batch_id, filter_dict)
        ndelete = collection.count_documents(filter_dict)
        # Remember the affected batches before deleting, for the source clean-up below.
        batch_ids_in_del = {x['batch_id'] for x in collection.find(filter_dict, {'batch_id'})}

        # Ask for confirmation before deleting many configurations at once.
        if ndelete >= 10:
            if input(f"Deleting {ndelete} configuration{s_if(ndelete)} from database collection. "
                     f"Are you sure? (y/n) ").lower() != "y":
                exit()
        else:
            logging.info(f"Deleting {ndelete} configuration{s_if(ndelete)} from database collection.")
        collection.delete_many(filter_dict)
    else:
        exp = collection.find_one({'_id': sacred_id})
        if exp is None:
            raise MongoDBError(f"No experiment found with ID {sacred_id}.")
        logging.info(f"Deleting experiment with ID {sacred_id}.")
        batch_ids_in_del = {exp['batch_id']}
        collection.delete_one({'_id': sacred_id})

    if batch_ids_in_del:
        # clean up the uploaded sources if no experiments of a batch remain
        delete_orphaned_sources(collection, batch_ids_in_del)
예제 #3
0
def start_experiments(db_collection_name, local, sacred_id, batch_id,
                      filter_dict, num_exps, unobserved, post_mortem, debug,
                      dry_run, output_to_console):
    """
    Start experiments from a collection, or only print their commands on a dry run.

    With ``debug`` set, exactly one experiment is run outside Slurm, unobserved, with
    post-mortem debugging, without file output, and with the root logger at VERBOSE level.
    If ``sacred_id`` is given, only that experiment is selected; otherwise the selection
    is built from ``batch_id`` and ``filter_dict``.
    """
    if debug:
        num_exps, unobserved, post_mortem = 1, True, True
        use_slurm = output_to_file = False
        logging.root.setLevel(logging.VERBOSE)
    else:
        use_slurm = not local
        output_to_file = not output_to_console

    query = (build_filter_dict([], batch_id, filter_dict)
             if sacred_id is None
             else {'_id': sacred_id})

    if dry_run:
        print_commands(db_collection_name, unobserved=unobserved, post_mortem=post_mortem,
                       num_exps=num_exps, filter_dict=query)
        return

    start_jobs(db_collection_name, slurm=use_slurm, unobserved=unobserved,
               post_mortem=post_mortem, num_exps=num_exps, filter_dict=query,
               dry_run=dry_run, output_to_file=output_to_file)
예제 #4
0
def cancel_experiments(db_collection_name, sacred_id, filter_states, batch_id, filter_dict):
    """
    Cancel experiments.

    Parameters
    ----------
    db_collection_name: str
        Database collection name.
    sacred_id: int or None
        ID of the experiment to cancel. If None, will use the other arguments to cancel possible multiple experiments.
    filter_states: str, list of strings or None
        List of statuses to filter for. Will cancel all jobs from the database collection
        with one of the given statuses. A single string is treated as a one-element list.
    batch_id: int or None
        The ID of the batch of experiments to cancel. All experiments that are staged together (i.e. within the same
        command line call) have the same batch ID.
    filter_dict: dict or None
        Arbitrary filter dictionary to use for cancelling experiments. Any experiments whose database entries match all
        keys/values of the dictionary will be cancelled.

    Returns
    -------
    None

    Notes
    -----
    A Slurm job is only cancelled (via ``scancel``) when every experiment still RUNNING in it
    is among the experiments being cancelled. It is NOT possible to cancel a single experiment
    in a Slurm job that runs other experiments too. Matching experiments are marked as
    interrupted in the database even if some ``scancel`` calls fail (e.g. because the job had
    already finished).
    """
    collection = get_collection(db_collection_name)
    if sacred_id is not None:
        logging.info(f"Cancelling experiment with ID {sacred_id}.")
        cancel_experiment_by_id(collection, sacred_id)
        return

    # Accept a single state as well as a list; set() over a bare string would split it into
    # characters, and set(None) would raise.
    if isinstance(filter_states, str):
        filter_states = [filter_states]
    # NOTE(review): detect_killed presumably refreshes the status of experiments whose jobs
    # died, so that filtering on these states sees up-to-date values.
    if filter_states and {*States.PENDING, *States.RUNNING, *States.KILLED} & set(filter_states):
        detect_killed(db_collection_name, print_detected=False)

    filter_dict = build_filter_dict(filter_states, batch_id, filter_dict)

    ncancel = collection.count_documents(filter_dict)
    if ncancel >= 10:
        if input(f"Cancelling {ncancel} experiment{s_if(ncancel)}. "
                 f"Are you sure? (y/n) ").lower() != "y":
            exit()
    else:
        logging.info(f"Cancelling {ncancel} experiment{s_if(ncancel)}.")

    # Only experiments that were submitted to Slurm have a job that could be cancelled.
    filter_dict_new = copy.deepcopy(filter_dict)
    filter_dict_new.update({'slurm.array_id': {'$exists': True}})
    exps = list(collection.find(filter_dict_new,
                                {'_id': 1, 'status': 1, 'slurm.array_id': 1, 'slurm.task_id': 1}))
    # Slurm jobs, as (array ID, task ID), that host at least one experiment to be cancelled.
    slurm_ids = {(e['slurm']['array_id'], e['slurm']['task_id']) for e in exps}
    # IDs of the experiments to be cancelled.
    exp_ids = {e['_id'] for e in exps}

    # Experiments currently RUNNING in any of those Slurm jobs, grouped by job. This must look
    # at the whole collection, not just at `exps`: otherwise running experiments that are NOT
    # being cancelled are invisible, and their Slurm job would be killed along with them.
    running_by_job = {}
    if slurm_ids:
        array_ids = list({a_id for a_id, _ in slurm_ids})
        running = collection.find({'slurm.array_id': {'$in': array_ids},
                                   'status': {'$in': States.RUNNING}},
                                  {'_id': 1, 'slurm.array_id': 1, 'slurm.task_id': 1})
        for e in running:
            slurm = e.get('slurm', {})
            job = (slurm.get('array_id'), slurm.get('task_id'))
            running_by_job.setdefault(job, set()).add(e['_id'])

    # A Slurm job can be cancelled altogether only if no running experiment in it survives.
    to_cancel = [f"{a_id}_{t_id}" for (a_id, t_id) in slurm_ids
                 if not running_by_job.get((a_id, t_id), set()) - exp_ids]

    # Cancel in chunks to keep the command line short. scancel fails for jobs that already
    # finished; that must neither abort the remaining chunks nor skip the database update.
    chunk_size = 100
    if to_cancel:
        for chunk in chunker(to_cancel, chunk_size):
            try:
                subprocess.run(["scancel", *chunk], check=True)
            except subprocess.CalledProcessError:
                logging.warning("One or multiple Slurm jobs were no longer running when I tried to cancel them.")

    # update database status and write the stop_time
    collection.update_many(filter_dict, {'$set': {"status": States.INTERRUPTED[0],
                                                  "stop_time": datetime.datetime.utcnow()}})
예제 #5
0
def start_experiments(db_collection_name, local, sacred_id, batch_id, filter_dict,
                      num_exps, post_mortem, debug, debug_server,
                      output_to_console, no_file_output, steal_slurm,
                      no_worker, set_to_pending=True,
                      worker_gpus=None, worker_cpus=None, worker_environment_vars=None):
    """
    Start staged experiments, either by queueing them on Slurm or with a local worker.

    With ``debug`` or ``debug_server`` set, exactly one experiment is run unobserved, with
    post-mortem debugging, console output, srun, and the root logger at VERBOSE level.

    Raises
    ------
    ArgumentError
        If a local-only option is combined with Slurm mode; if ``post_mortem`` or
        ``output_to_console`` is used in regular (non-debug) Slurm mode; or if
        ``debug_server`` is used with source code stored in the database.
    """
    output_to_file = not no_file_output
    launch_worker = not no_worker

    if debug or debug_server:
        num_exps = 1
        unobserved = True
        post_mortem = True
        output_to_console = True
        srun = True
        logging.root.setLevel(logging.VERBOSE)
    else:
        unobserved = False
        srun = False

    if local:
        check_compute_node()

    if not local:
        local_kwargs = {
                "--no-worker": no_worker,
                "--steal-slurm": steal_slurm,
                "--worker-gpus": worker_gpus,
                "--worker-cpus": worker_cpus,
                "--worker-environment-vars": worker_environment_vars}
        for key, val in local_kwargs.items():
            if val:
                raise ArgumentError(f"The argument '{key}' only works in local mode, not in Slurm mode.")
    if not local and not srun:
        non_sbatch_kwargs = {
                "--post-mortem": post_mortem,
                "--output-to-console": output_to_console}
        for key, val in non_sbatch_kwargs.items():
            if val:
                raise ArgumentError(f"The argument '{key}' does not work in regular Slurm mode. "
                                    "Remove the argument or use '--debug'.")

    # Unobserved runs are not recorded, so they are not marked as pending either.
    if unobserved:
        set_to_pending = False

    if sacred_id is None:
        filter_dict = build_filter_dict([], batch_id, filter_dict)
        # Default to staged experiments unless the caller filters on status explicitly.
        if 'status' not in filter_dict:
            filter_dict['status'] = {"$in": States.STAGED}
    else:
        filter_dict = {'_id': sacred_id}

    collection = get_collection(db_collection_name)

    staged_experiments = prepare_experiments(
            collection=collection, filter_dict=filter_dict, num_exps=num_exps,
            slurm=not local, set_to_pending=set_to_pending, print_pending=local)

    if debug_server:
        # Check the staged experiments without indexing, so an empty selection does not
        # crash with an IndexError.
        use_stored_sources = any('source_files' in exp['seml'] for exp in staged_experiments)
        if use_stored_sources:
            raise ArgumentError("Cannot use a debug server with source code that is loaded from the MongoDB. "
                                "Use the `--no-code-checkpoint` option when adding the experiment.")

    if not local:
        add_to_slurm_queue(collection=collection, exps_list=staged_experiments, unobserved=unobserved,
                           post_mortem=post_mortem, output_to_file=output_to_file,
                           output_to_console=output_to_console, srun=srun,
                           debug_server=debug_server)
    elif launch_worker:
        start_local_worker(collection=collection, num_exps=num_exps, filter_dict=filter_dict, unobserved=unobserved,
                           post_mortem=post_mortem, steal_slurm=steal_slurm,
                           output_to_console=output_to_console, output_to_file=output_to_file,
                           gpus=worker_gpus, cpus=worker_cpus, environment_variables=worker_environment_vars,
                           debug_server=debug_server)
예제 #6
0
def print_command(db_collection_name, sacred_id, batch_id, filter_dict, num_exps,
                  worker_gpus=None, worker_cpus=None, worker_environment_vars=None):
    """
    Log the commands for running and debugging the selected experiments.

    For the first experiment this logs:

    - the executable and Anaconda environment;
    - debugger arguments for VS Code and PyCharm;
    - the post-mortem and remote-debugging commands.

    It then logs the raw command of every selected experiment. If ``sacred_id`` is None,
    experiments are selected by ``batch_id``/``filter_dict``, defaulting to staged ones.
    The root logger is temporarily raised to VERBOSE and always restored afterwards.
    """
    collection = get_collection(db_collection_name)

    if sacred_id is None:
        filter_dict = build_filter_dict([], batch_id, filter_dict)
        if 'status' not in filter_dict:
            filter_dict['status'] = {"$in": States.STAGED}
    else:
        filter_dict = {'_id': sacred_id}

    env_dict = get_environment_variables(worker_gpus, worker_cpus, worker_environment_vars)
    env_str = " ".join([f"{k}={v}" for k, v in env_dict.items()])
    if env_str:
        env_str += " "

    # Query before changing the log level so the early return cannot leak the VERBOSE level.
    exps_list = list(collection.find(filter_dict, limit=num_exps))
    if not exps_list:
        return

    orig_level = logging.root.level
    logging.root.setLevel(logging.VERBOSE)
    try:
        exp = exps_list[0]
        _, exe, config = get_command_from_exp(exp, collection.name,
                                              verbose=logging.root.level <= logging.VERBOSE,
                                              unobserved=True, post_mortem=False)
        env = exp['seml'].get('conda_environment')

        logging.info("********** First experiment **********")
        logging.info(f"Executable: {exe}")
        if env is not None:
            logging.info(f"Anaconda environment: {env}")
        config.insert(0, 'with')
        config.append('--debug')

        # Remove double quotes, change single quotes to escaped double quotes
        config_vscode = [c.replace('"', '').replace("'", '\\"') for c in config]

        logging.info("\nArguments for VS Code debugger:")
        logging.info('["' + '", "'.join(config_vscode) + '"]')
        logging.info("Arguments for PyCharm debugger:")
        logging.info(" ".join(config))

        logging.info("\nCommand for post-mortem debugging:")
        interpreter, exe, config = get_command_from_exp(exp, collection.name,
                                                        verbose=logging.root.level <= logging.VERBOSE,
                                                        unobserved=True, post_mortem=True)
        logging.info(f"{env_str}{interpreter} {exe} with {' '.join(config)}")

        logging.info("\nCommand for remote debugging:")
        interpreter, exe, config = get_command_from_exp(exp, collection.name,
                                                        verbose=logging.root.level <= logging.VERBOSE,
                                                        unobserved=True, debug_server=True, print_info=False)
        logging.info(f"{env_str}{interpreter} {exe} with {' '.join(config)}")

        logging.info("\n********** All raw commands **********")
    finally:
        # Restore the caller's level even if building a command raised.
        logging.root.setLevel(orig_level)

    # The raw commands are logged at the caller's original level, as before.
    for exp in exps_list:
        interpreter, exe, config = get_command_from_exp(
                exp, collection.name, verbose=logging.root.level <= logging.VERBOSE)
        logging.info(f"{env_str}{interpreter} {exe} with {' '.join(config)}")
예제 #7
0
def start_experiments(db_collection_name,
                      local,
                      sacred_id,
                      batch_id,
                      filter_dict,
                      num_exps,
                      post_mortem,
                      debug,
                      debug_server,
                      print_command,
                      output_to_console,
                      no_file_output,
                      steal_slurm,
                      no_worker,
                      set_to_pending=True,
                      worker_gpus=None,
                      worker_cpus=None,
                      worker_environment_vars=None):
    """
    Start experiments on Slurm or with a local worker, or only print their commands.

    With ``debug`` or ``debug_server`` set, exactly one experiment is run unobserved, with
    post-mortem debugging, console output, srun, and the root logger at VERBOSE level. If
    ``sacred_id`` is given, that experiment is selected regardless of its status. The
    caller's ``filter_dict`` is never modified.

    Note: the boolean parameter ``print_command`` shadows any module-level function of the
    same name inside this function; it is kept for interface compatibility.

    Raises
    ------
    ArgumentError
        If a local-only option is combined with Slurm mode, or ``post_mortem`` /
        ``output_to_console`` is used in regular (non-debug) Slurm mode.
    """
    use_slurm = not local
    output_to_file = not no_file_output
    launch_worker = not no_worker

    if debug or debug_server:
        num_exps = 1
        unobserved = True
        post_mortem = True
        output_to_console = True
        srun = True
        logging.root.setLevel(logging.VERBOSE)
    else:
        unobserved = False
        srun = False

    if not local:
        local_kwargs = {
            "--no-worker": no_worker,
            "--steal-slurm": steal_slurm,
            "--worker-gpus": worker_gpus,
            "--worker-cpus": worker_cpus,
            "--worker-environment-vars": worker_environment_vars
        }
        for key, val in local_kwargs.items():
            if val:
                raise ArgumentError(
                    f"The argument '{key}' only works in local mode, not in Slurm mode."
                )
    if not local and not srun:
        non_sbatch_kwargs = {
            "--post-mortem": post_mortem,
            "--output-to-console": output_to_console
        }
        for key, val in non_sbatch_kwargs.items():
            if val:
                raise ArgumentError(
                    f"The argument '{key}' does not work in regular Slurm mode. "
                    "Remove the argument or use '--debug'.")

    # Work on a copy so the caller's dictionary is never mutated below.
    filter_dict = {} if filter_dict is None else dict(filter_dict)

    # Unobserved runs are not recorded, so they are not marked as pending either.
    if unobserved:
        set_to_pending = False

    if worker_environment_vars is None:
        worker_environment_vars = {}

    if sacred_id is None:
        filter_dict = build_filter_dict([], batch_id, filter_dict)
    else:
        # if we have a specific sacred ID, we ignore the state of the experiment and run it in any case.
        all_states = (States.PENDING + States.STAGED + States.RUNNING +
                      States.FAILED + States.INTERRUPTED + States.KILLED +
                      States.COMPLETED)
        filter_dict.update({'_id': sacred_id, "status": {"$in": all_states}})

    collection = get_collection(db_collection_name)

    if print_command:
        print_commands(collection,
                       unobserved=unobserved,
                       post_mortem=post_mortem,
                       debug_server=debug_server,
                       num_exps=num_exps,
                       filter_dict=filter_dict)
        return

    staged_experiments = prepare_staged_experiments(
        collection=collection,
        filter_dict=filter_dict,
        num_exps=num_exps,
        slurm=use_slurm,
        set_to_pending=set_to_pending,
        print_pending=not use_slurm)

    if use_slurm:
        add_to_slurm_queue(collection=collection,
                           exps_list=staged_experiments,
                           unobserved=unobserved,
                           post_mortem=post_mortem,
                           output_to_file=output_to_file,
                           output_to_console=output_to_console,
                           srun=srun,
                           debug_server=debug_server)
    elif launch_worker:
        start_local_worker(collection=collection,
                           num_exps=num_exps,
                           filter_dict=filter_dict,
                           unobserved=unobserved,
                           post_mortem=post_mortem,
                           steal_slurm=steal_slurm,
                           output_to_console=output_to_console,
                           output_to_file=output_to_file,
                           gpus=worker_gpus,
                           cpus=worker_cpus,
                           environment_variables=worker_environment_vars,
                           debug_server=debug_server)