Example #1
0
class Agent(object):
    """Sweep agent that polls for jobs and runs them one at a time.

    The agent registers itself for a sweep, then repeatedly asks the backend
    for a command and calls the user-supplied ``function`` in-process for each
    job it receives.

    Args:
        sweep_id: Sweep identifier, passed to ``util.parse_sweep_id`` as the
            ``name`` part (presumably it may also embed entity/project --
            confirm against ``parse_sweep_id``).
        project: Optional project, used when parsing does not yield one.
        entity: Optional entity, used when parsing does not yield one.
        function: Zero-argument callable invoked once per job.
        count: Optional maximum number of jobs to run; a falsy value means
            no limit.
    """

    def __init__(self,
                 sweep_id,
                 project=None,
                 entity=None,
                 function=None,
                 count=None):
        # The raw identifier is kept separately; ``_sweep_id`` is only filled
        # in once ``setup`` has parsed it.
        self._sweep_path = sweep_id
        self._sweep_id = None
        self._project = project
        self._entity = entity
        self._function = function
        self._count = count
        self._api = InternalApi()
        self._agent_id = None

    def register(self):
        """Register this host as an agent of the sweep and store its id."""
        agent = self._api.register_agent(socket.gethostname(),
                                         sweep_id=self._sweep_id)
        self._agent_id = agent["id"]

    def check_queue(self):
        """Send one heartbeat and return the first command as a ``Job``.

        Returns:
            A ``Job`` built from the first command, or ``None`` when the
            heartbeat returned no commands.
        """
        run_status = dict()
        commands = self._api.agent_heartbeat(self._agent_id, {}, run_status)
        if not commands:
            return None
        return Job(commands[0])

    def run_job(self, job):
        """Write the job's config, export the run environment and run it.

        Args:
            job: Object exposing ``run_id`` and a ``config`` mapping whose
                values are dicts with a ``"value"`` key.

        Returns:
            ``True`` when the polling loop should stop (the function raised
            or was interrupted); ``None`` otherwise.
        """
        run_id = job.run_id

        config_file = os.path.join("wandb", "sweep-" + self._sweep_id,
                                   "config-" + run_id + ".yaml")
        config_util.save_config_file_from_dict(config_file, job.config)
        os.environ[wandb.env.RUN_ID] = run_id
        os.environ[wandb.env.CONFIG_PATHS] = config_file
        os.environ[wandb.env.SWEEP_ID] = self._sweep_id
        wandb.setup(_reset=True)

        config_lines = [
            "\t{}: {}".format(k, v["value"]) for k, v in job.config.items()
        ]
        print("wandb: Agent Starting Run: {} with config:\n".format(run_id) +
              "\n".join(config_lines))
        try:
            self._function()
            if wandb.run:
                wandb.join()
        except KeyboardInterrupt as e:
            print("Keyboard interrupt", e)
            return True
        except Exception as e:
            # Best effort: report the failure and ask the caller to stop.
            print("Problem", e)
            return True

    def setup(self):
        """Parse the sweep path, export entity/project and register.

        Returns:
            ``True`` when the agent was registered; ``False`` when
            ``util.parse_sweep_id`` reported an error (the error is printed
            with ``wandb.termerror`` and nothing is registered).
        """
        parts = dict(entity=self._entity,
                     project=self._project,
                     name=self._sweep_path)
        err = util.parse_sweep_id(parts)
        if err:
            wandb.termerror(err)
            return False
        entity = parts.get("entity") or self._entity
        project = parts.get("project") or self._project
        sweep_id = parts.get("name") or self._sweep_id

        if entity:
            wandb.env.set_entity(entity)
        if project:
            wandb.env.set_project(project)
        if sweep_id:
            self._sweep_id = sweep_id
        self.register()
        return True

    def loop(self):
        """Poll for jobs and run them until told to stop.

        The loop ends when setup fails, when a job reports ``done()``, when
        ``run_job`` asks to stop, or when ``count`` jobs have been run.
        """
        # Without a successful setup there is no agent id to poll with, so
        # bail out instead of heartbeating forever with ``None``.
        if not self.setup():
            return
        count = 0
        while True:
            job = self.check_queue()
            if not job:
                time.sleep(20)
                continue
            if job.done():
                break
            count += 1
            if self.run_job(job):
                break
            if self._count and count >= self._count:
                break
            time.sleep(5)
Example #2
0
class Agent(object):
    """Threaded sweep agent: a heartbeat thread feeds jobs to a run loop.

    ``run`` starts a daemon thread that heartbeats the backend and queues the
    jobs it receives; the calling thread takes jobs off the queue and runs the
    user-supplied ``function`` for each one in a worker thread, one at a time.

    Args:
        sweep_id: Sweep identifier, passed to ``util.parse_sweep_id`` as the
            ``name`` part.
        project: Optional project, used when parsing does not yield one.
        entity: Optional entity, used when parsing does not yield one.
        function: Zero-argument callable invoked once per job.
        count: Optional maximum number of jobs to take from the queue; a
            falsy value means no limit.
    """

    FLAPPING_MAX_SECONDS = 60
    FLAPPING_MAX_FAILURES = 3
    MAX_INITIAL_FAILURES = 5
    # Pause between two heartbeats, in seconds.
    HEARTBEAT_INTERVAL_SECONDS = 5

    def __init__(self,
                 sweep_id=None,
                 project=None,
                 entity=None,
                 function=None,
                 count=None):
        # The raw identifier is kept separately; ``_sweep_id`` is only filled
        # in once ``_setup`` has parsed it.
        self._sweep_path = sweep_id
        self._sweep_id = None
        self._project = project
        self._entity = entity
        self._function = function
        self._count = count
        self._api = InternalApi()
        self._agent_id = None
        self._max_initial_failures = wandb.env.get_agent_max_initial_failures(
            self.MAX_INITIAL_FAILURES)
        # if the directory to log to is not set, set it
        if os.environ.get(wandb.env.DIR) is None:
            os.environ[wandb.env.DIR] = os.path.abspath(os.getcwd())

    def _init(self):
        """Reset per-run state so the same instance can be run again."""
        self._run_threads = {}   # run_id -> worker thread
        self._run_status = {}    # run_id -> RunStatus member
        self._queue = queue.Queue()
        self._exit_flag = False
        self._exceptions = {}    # run_id -> exception raised by the run
        self._start_time = time.time()

    def _register(self):
        """Register this host as an agent of the sweep and store its id."""
        logger.debug("Agent._register()")
        agent = self._api.register_agent(socket.gethostname(),
                                         sweep_id=self._sweep_id)
        self._agent_id = agent["id"]
        logger.debug("agent_id = {}".format(self._agent_id))

    def _setup(self):
        """Reset state, parse the sweep path, export env and register.

        Returns:
            ``True`` when the agent was registered; ``False`` when
            ``util.parse_sweep_id`` reported an error (the error is printed
            with ``wandb.termerror`` and nothing is registered).
        """
        logger.debug("Agent._setup()")
        self._init()
        parts = dict(entity=self._entity,
                     project=self._project,
                     name=self._sweep_path)
        err = util.parse_sweep_id(parts)
        if err:
            wandb.termerror(err)
            return False
        entity = parts.get("entity") or self._entity
        project = parts.get("project") or self._project
        sweep_id = parts.get("name") or self._sweep_id
        if sweep_id:
            os.environ[wandb.env.SWEEP_ID] = sweep_id
        if entity:
            wandb.env.set_entity(entity)
        if project:
            wandb.env.set_project(project)
        if sweep_id:
            self._sweep_id = sweep_id
        self._register()
        return True

    def _stop_run(self, run_id):
        """Mark ``run_id`` as stopped and terminate its thread, if any."""
        logger.debug("Stopping run {}.".format(run_id))
        self._run_status[run_id] = RunStatus.STOPPED
        thread = self._run_threads.get(run_id)
        if thread:
            _terminate_thread(thread)

    def _stop_all_runs(self):
        """Stop every run that has a registered thread."""
        logger.debug("Stopping all runs.")
        # Iterate over a snapshot of the keys.
        for run in list(self._run_threads.keys()):
            self._stop_run(run)

    def _exit(self):
        """Stop all runs and raise the exit flag for both loops."""
        self._stop_all_runs()
        self._exit_flag = True

    def _heartbeat(self):
        """Heartbeat the backend until the exit flag is set.

        Each beat reports the runs that are queued or running and handles the
        first returned command: ``run`` jobs are queued, ``stop`` stops one
        run, ``exit`` stops everything and ends the loop.
        """
        while True:
            if self._exit_flag:
                return
            run_status = {
                run: True
                for run, status in self._run_status.items()
                if status in (RunStatus.QUEUED, RunStatus.RUNNING)
            }
            commands = self._api.agent_heartbeat(self._agent_id, {},
                                                 run_status)
            if commands:
                job = Job(commands[0])
                logger.debug("Job received: {}".format(job))
                if job.type == "run":
                    # Record the status before queueing: the run loop reads
                    # ``_run_status[run_id]`` as soon as it gets the job.
                    self._run_status[job.run_id] = RunStatus.QUEUED
                    self._queue.put(job)
                elif job.type == "stop":
                    self._stop_run(job.run_id)
                elif job.type == "exit":
                    self._exit()
                    return
            # Sleep on every iteration, including when no command came back,
            # so an idle agent does not hammer the API.
            time.sleep(self.HEARTBEAT_INTERVAL_SECONDS)

    def _run_jobs_from_queue(self):  # noqa:C901
        """Take jobs from the queue and run each one in a joined thread.

        Returns when the exit flag is set, when ``count`` jobs have been
        taken, on Ctrl+C, or when one of the failure checks trips. Any other
        exception is re-raised unless the exit flag is already set.
        """
        global _INSTANCES
        _INSTANCES += 1
        try:
            waiting = False
            count = 0
            while True:
                if self._exit_flag:
                    return
                try:
                    try:
                        job = self._queue.get(timeout=5)
                        if self._exit_flag:
                            logger.debug("Exiting main loop due to exit flag.")
                            wandb.termlog("Sweep Agent: Exiting.")
                            return
                    except queue.Empty:
                        # Announce the idle state only once per idle period.
                        if not waiting:
                            logger.debug("Paused.")
                            wandb.termlog("Sweep Agent: Waiting for job.")
                            waiting = True
                        time.sleep(5)
                        if self._exit_flag:
                            logger.debug("Exiting main loop due to exit flag.")
                            wandb.termlog("Sweep Agent: Exiting.")
                            return
                        continue
                    if waiting:
                        logger.debug("Resumed.")
                        wandb.termlog("Job received.")
                        waiting = False
                    count += 1
                    run_id = job.run_id
                    # The run may have been stopped while it sat in the queue.
                    if self._run_status[run_id] == RunStatus.STOPPED:
                        continue
                    logger.debug(
                        "Spawning new thread for run {}.".format(run_id))
                    thread = threading.Thread(target=self._run_job,
                                              args=(job, ))
                    self._run_threads[run_id] = thread
                    # Mark RUNNING before starting: ``_run_job`` only records
                    # an error when it sees RUNNING, so a run that fails
                    # immediately would otherwise be reported as DONE.
                    self._run_status[run_id] = RunStatus.RUNNING
                    thread.start()
                    thread.join()
                    logger.debug("Thread joined for run {}.".format(run_id))
                    if self._run_status[run_id] == RunStatus.RUNNING:
                        self._run_status[run_id] = RunStatus.DONE
                    elif self._run_status[run_id] == RunStatus.ERRORED:
                        exc = self._exceptions[run_id]
                        logger.error("Run {} errored: {}".format(
                            run_id, repr(exc)))
                        wandb.termerror("Run {} errored: {}".format(
                            run_id, repr(exc)))
                        # NOTE(review): with the variable set to "true" the
                        # agent exits on the first errored run, which seems
                        # at odds with the "To disable this check" hint
                        # below -- confirm the intended semantics.
                        if os.getenv(
                                wandb.env.AGENT_DISABLE_FLAPPING) == "true":
                            self._exit_flag = True
                            return
                        elif (time.time() - self._start_time <
                              self.FLAPPING_MAX_SECONDS) and (
                                  len(self._exceptions) >=
                                  self.FLAPPING_MAX_FAILURES):
                            msg = "Detected {} failed runs in the first {} seconds, killing sweep.".format(
                                self.FLAPPING_MAX_FAILURES,
                                self.FLAPPING_MAX_SECONDS)
                            logger.error(msg)
                            wandb.termerror(msg)
                            wandb.termlog(
                                "To disable this check set WANDB_AGENT_DISABLE_FLAPPING=true"
                            )
                            self._exit_flag = True
                            return
                        # Every job taken so far has failed and the number of
                        # failures exceeds the configured maximum.
                        if (self._max_initial_failures < len(self._exceptions)
                                and len(self._exceptions) >= count):
                            msg = "Detected {} failed runs in a row at start, killing sweep.".format(
                                self._max_initial_failures)
                            logger.error(msg)
                            wandb.termerror(msg)
                            wandb.termlog(
                                "To change this value set WANDB_AGENT_MAX_INITIAL_FAILURES=val"
                            )
                            self._exit_flag = True
                            return
                    if self._count and self._count == count:
                        logger.debug(
                            "Exiting main loop because max count reached.")
                        self._exit_flag = True
                        return
                except KeyboardInterrupt:
                    logger.debug("Ctrl + C detected. Stopping sweep.")
                    wandb.termlog("Ctrl + C detected. Stopping sweep.")
                    self._exit()
                    return
                except Exception:
                    if self._exit_flag:
                        logger.debug("Exiting main loop due to exit flag.")
                        wandb.termlog("Sweep Agent: Killed.")
                        return
                    # Bare raise keeps the original traceback intact.
                    raise
        finally:
            _INSTANCES -= 1

    def _run_job(self, job):
        """Write the job's config, export the run environment and run it.

        Exceptions other than ``KeyboardInterrupt`` are not propagated: the
        run is finished with exit code 1 and, if it was still marked RUNNING,
        recorded as ERRORED together with the exception.

        Args:
            job: Object exposing ``run_id`` and a ``config`` mapping whose
                values are dicts with a ``"value"`` key.
        """
        # Bound outside the ``try`` so the error handler can always use it.
        run_id = job.run_id
        try:
            config_file = os.path.join("wandb", "sweep-" + self._sweep_id,
                                       "config-" + run_id + ".yaml")
            os.environ[wandb.env.RUN_ID] = run_id
            os.environ[wandb.env.CONFIG_PATHS] = os.path.join(
                os.environ[wandb.env.DIR], config_file)
            wandb.wandb_lib.config_util.save_config_file_from_dict(
                os.environ[wandb.env.CONFIG_PATHS], job.config)
            os.environ[wandb.env.SWEEP_ID] = self._sweep_id
            wandb_sdk.wandb_setup._setup(_reset=True)

            wandb.termlog("Agent Starting Run: {} with config:".format(run_id))
            for k, v in job.config.items():
                wandb.termlog("\t{}: {}".format(k, v["value"]))

            self._function()
            wandb.finish()
        except KeyboardInterrupt:
            raise
        except Exception as e:
            wandb.finish(exit_code=1)
            # A run that was stopped on purpose is not counted as an error.
            if self._run_status[run_id] == RunStatus.RUNNING:
                self._run_status[run_id] = RunStatus.ERRORED
                self._exceptions[run_id] = e

    def run(self):
        """Set up the agent, start heartbeating and process jobs.

        Returns without starting anything when setup fails.
        """
        logger.info(
            "Starting sweep agent: entity={}, project={}, count={}".format(
                self._entity, self._project, self._count))
        # Without a successful setup there is no agent id to heartbeat with.
        if not self._setup():
            return
        self._heartbeat_thread = threading.Thread(target=self._heartbeat,
                                                  daemon=True)
        self._heartbeat_thread.start()
        self._run_jobs_from_queue()
Example #3
0
class Agent(object):
    """Threaded sweep agent: a heartbeat thread feeds jobs to a run loop.

    ``run`` starts a daemon thread that heartbeats the backend and queues the
    jobs it receives; the calling thread takes jobs off the queue and runs the
    user-supplied ``function`` for each one in a worker thread, one at a time.

    Args:
        sweep_id: Sweep identifier, passed to ``util.parse_sweep_id`` as the
            ``name`` part.
        project: Optional project, used when parsing does not yield one.
        entity: Optional entity, used when parsing does not yield one.
        function: Zero-argument callable invoked once per job.
        count: Optional maximum number of jobs to run; a falsy value means
            no limit.
    """

    FLAPPING_MAX_SECONDS = 60
    FLAPPING_MAX_FAILURES = 3
    # Pause between two heartbeats, in seconds.
    HEARTBEAT_INTERVAL_SECONDS = 5

    def __init__(self,
                 sweep_id=None,
                 project=None,
                 entity=None,
                 function=None,
                 count=None):
        # The raw identifier is kept separately; ``_sweep_id`` is only filled
        # in once ``_setup`` has parsed it.
        self._sweep_path = sweep_id
        self._sweep_id = None
        self._project = project
        self._entity = entity
        self._function = function
        self._count = count
        self._api = InternalApi()
        self._agent_id = None

    def _init(self):
        """Reset per-run state so the same instance can be run again."""
        self._run_threads = {}      # run_id -> worker thread
        self._queue = queue.Queue()
        self._stopped_runs = set()  # runs stopped on request
        self._exit_flag = False
        self._errored_runs = {}     # run_id -> exception raised by the run

    def _register(self):
        """Register this host as an agent of the sweep and store its id."""
        logger.debug("Agent._register()")
        agent = self._api.register_agent(socket.gethostname(),
                                         sweep_id=self._sweep_id)
        self._agent_id = agent["id"]
        logger.debug("agent_id = {}".format(self._agent_id))

    def _setup(self):
        """Reset state, parse the sweep path, export env and register.

        Returns:
            ``True`` when the agent was registered; ``False`` when
            ``util.parse_sweep_id`` reported an error (the error is printed
            with ``wandb.termerror`` and nothing is registered).
        """
        logger.debug("Agent._setup()")
        self._init()
        parts = dict(entity=self._entity,
                     project=self._project,
                     name=self._sweep_path)
        err = util.parse_sweep_id(parts)
        if err:
            wandb.termerror(err)
            return False
        entity = parts.get("entity") or self._entity
        project = parts.get("project") or self._project
        sweep_id = parts.get("name") or self._sweep_id
        if sweep_id:
            os.environ[wandb.env.SWEEP_ID] = sweep_id
        if entity:
            wandb.env.set_entity(entity)
        if project:
            wandb.env.set_project(project)
        if sweep_id:
            self._sweep_id = sweep_id
        self._register()
        return True

    def _run_status(self):
        """Return ``{run_id: True}`` for live runs and forget dead ones.

        Threads that are no longer alive are removed from ``_run_threads``.
        """
        run_status = {}
        # Iterate over a snapshot: the run loop and ``_stop_run`` add and
        # remove entries from another thread.
        for run_id, thread in list(self._run_threads.items()):
            if thread.is_alive():
                run_status[run_id] = True
            else:
                # The entry may already have been removed elsewhere.
                self._run_threads.pop(run_id, None)
        return run_status

    def _stop_run(self, run_id):
        """Mark ``run_id`` as stopped and terminate its thread, if any."""
        logger.debug("Stopping run {}.".format(run_id))
        self._stopped_runs.add(run_id)
        # ``pop`` instead of get + del: the entry may be removed concurrently
        # by the run loop or by ``_run_status``.
        thread = self._run_threads.pop(run_id, None)
        if thread:
            _terminate_thread(thread)

    def _stop_all_runs(self):
        """Stop every run that has a registered thread."""
        logger.debug("Stopping all runs.")
        # Iterate over a snapshot: ``_stop_run`` removes entries.
        for run in list(self._run_threads.keys()):
            self._stop_run(run)

    def _exit(self):
        """Stop all runs and raise the exit flag for both loops."""
        self._stop_all_runs()
        self._exit_flag = True

    def _heartbeat(self):
        """Heartbeat the backend until the exit flag is set.

        Each beat reports the live runs and handles the first returned
        command: ``run`` jobs are queued, ``stop`` stops one run, ``exit``
        stops everything and ends the loop.
        """
        while True:
            if self._exit_flag:
                return
            commands = self._api.agent_heartbeat(self._agent_id, {},
                                                 self._run_status())
            if commands:
                job = Job(commands[0])
                logger.debug("Job received: {}".format(job))
                if job.type == "run":
                    self._queue.put(job)
                elif job.type == "stop":
                    self._stop_run(job.run_id)
                elif job.type == "exit":
                    self._exit()
                    return
            # Sleep on every iteration, including when no command came back,
            # so an idle agent does not hammer the API.
            time.sleep(self.HEARTBEAT_INTERVAL_SECONDS)

    def _run_jobs_from_queue(self):
        """Take jobs from the queue and run each one in a joined thread.

        Returns when the exit flag is set, when ``count`` jobs have been run,
        on Ctrl+C, or when the failure check trips. Any other exception is
        re-raised unless the exit flag is already set.
        """
        waiting = False
        count = 0
        while True:
            if self._exit_flag:
                return
            try:
                try:
                    job = self._queue.get(timeout=5)
                    if self._exit_flag:
                        logger.debug("Exiting main loop due to exit flag.")
                        wandb.termlog("Sweep Agent: Exiting.")
                        return
                except queue.Empty:
                    # Announce the idle state only once per idle period.
                    if not waiting:
                        logger.debug("Paused.")
                        wandb.termlog("Sweep Agent: Waiting for job.")
                        waiting = True
                    time.sleep(5)
                    if self._exit_flag:
                        logger.debug("Exiting main loop due to exit flag.")
                        wandb.termlog("Sweep Agent: Exiting.")
                        return
                    continue
                if waiting:
                    logger.debug("Resumed.")
                    wandb.termlog("Job received.")
                    waiting = False
                count += 1
                run_id = job.run_id
                logger.debug("Spawning new thread for run {}.".format(run_id))
                thread = threading.Thread(target=self._run_job, args=(job, ))
                self._run_threads[run_id] = thread
                thread.start()
                thread.join()
                logger.debug("Thread joined for run {}.".format(run_id))
                exc = self._errored_runs.get(run_id)
                if exc:
                    logger.error("Run {} errored: {}".format(
                        run_id, repr(exc)))
                    wandb.termerror("Run {} errored: {}".format(
                        run_id, repr(exc)))
                    # NOTE(review): with the variable set to "true" the agent
                    # exits on the first errored run, which seems at odds
                    # with the "To disable this check" hint below -- confirm
                    # the intended semantics.
                    if os.getenv(wandb.env.AGENT_DISABLE_FLAPPING) == "true":
                        self._exit_flag = True
                        return
                    elif (time.time() - wandb.START_TIME <
                          self.FLAPPING_MAX_SECONDS) and (
                              len(self._errored_runs) >=
                              self.FLAPPING_MAX_FAILURES):
                        msg = "Detected {} failed runs in the first {} seconds, killing sweep.".format(
                            self.FLAPPING_MAX_FAILURES,
                            self.FLAPPING_MAX_SECONDS)
                        logger.error(msg)
                        wandb.termerror(msg)
                        wandb.termlog(
                            "To disable this check set WANDB_AGENT_DISABLE_FLAPPING=true"
                        )
                        self._exit_flag = True
                        return
                # The heartbeat thread may already have dropped this entry
                # (dead-thread cleanup or a stop command), so do not assume
                # the key is still present.
                self._run_threads.pop(run_id, None)
                if self._count and self._count == count:
                    logger.debug(
                        "Exiting main loop because max count reached.")
                    self._exit_flag = True
                    return
            except KeyboardInterrupt:
                logger.debug("Ctrl + C detected. Stopping sweep.")
                wandb.termlog("Ctrl + C detected. Stopping sweep.")
                self._exit()
                return
            except Exception:
                if self._exit_flag:
                    logger.debug("Exiting main loop due to exit flag.")
                    wandb.termlog("Sweep Agent: Killed.")
                    return
                # Bare raise keeps the original traceback intact.
                raise

    def _run_job(self, job):
        """Write the job's config, export the run environment and run it.

        Exceptions other than ``KeyboardInterrupt`` are not propagated: the
        run is finished with exit code 1 and, unless it had been stopped on
        request, recorded in ``_errored_runs``.

        Args:
            job: Object exposing ``run_id`` and a ``config`` mapping whose
                values are dicts with a ``"value"`` key.
        """
        # Bound outside the ``try`` so the error handler can always use it.
        run_id = job.run_id
        try:
            config_file = os.path.join("wandb", "sweep-" + self._sweep_id,
                                       "config-" + run_id + ".yaml")
            config_util.save_config_file_from_dict(config_file, job.config)
            os.environ[wandb.env.RUN_ID] = run_id
            os.environ[wandb.env.CONFIG_PATHS] = config_file
            os.environ[wandb.env.SWEEP_ID] = self._sweep_id
            wandb_sdk.wandb_setup._setup(_reset=True)

            wandb.termlog("Agent Starting Run: {} with config:".format(run_id))
            for k, v in job.config.items():
                wandb.termlog("\t{}: {}".format(k, v["value"]))

            self._function()
            wandb.finish()
        except KeyboardInterrupt:
            raise
        except Exception as e:
            wandb.finish(exit_code=1)
            # A run that was stopped on purpose is not counted as an error.
            if run_id in self._stopped_runs:
                self._stopped_runs.remove(run_id)
            else:
                self._errored_runs[run_id] = e

    def run(self):
        """Set up the agent, start heartbeating and process jobs.

        Returns without starting anything when setup fails.
        """
        logger.info(
            "Starting sweep agent: entity={}, project={}, count={}".format(
                self._entity, self._project, self._count))
        # Without a successful setup there is no agent id to heartbeat with.
        if not self._setup():
            return
        self._heartbeat_thread = threading.Thread(target=self._heartbeat,
                                                  daemon=True)
        self._heartbeat_thread.start()
        self._run_jobs_from_queue()