コード例 #1
0
    def statusUpdate(self, driver: MesosSchedulerDriver, update: Dict):
        """Forward a status update to the parent, then stop when done.

        Every update whose ``state`` is not in ``LIVE_TASK_STATES`` counts
        down ``finished_countdown`` by one. Afterwards the driver is stopped
        whenever ``need_to_stop()`` reports true.
        """
        super().statusUpdate(driver, update)

        no_longer_live = update["state"] not in LIVE_TASK_STATES
        if no_longer_live:
            self.finished_countdown -= 1

        # Shut the driver down once the task has run and finished.
        should_stop = self.need_to_stop()
        if should_stop:
            driver.stop()
コード例 #2
0
    def tasks_and_state_for_offer(
            self, driver: MesosSchedulerDriver, offer,
            state: ConstraintState) -> Tuple[List[TaskInfo], ConstraintState]:
        """Choose tasks for ``offer``, honouring dry-run and stop conditions.

        When neither dry-run mode nor ``need_to_stop()`` applies, the parent
        implementation decides. Otherwise nothing is launched and the driver
        is stopped; in dry-run mode the parent's choice is computed first and
        only printed.

        Returns:
            A ``(tasks, state)`` pair, which is ``([], state)`` whenever the
            driver was stopped here.
        """
        if not self.dry_run and not self.need_to_stop():
            return super().tasks_and_state_for_offer(driver, offer, state)

        # In a dry run, satisfy the exit conditions once an offer arrives.
        if self.dry_run:
            would_launch, _ = super().tasks_and_state_for_offer(
                driver, offer, state)
            paasta_print("Would have launched: ", would_launch)
        driver.stop()
        return [], state
コード例 #3
0
ファイル: schedule.py プロジェクト: posens/dpark
class MesosScheduler(DAGScheduler):
    """Schedule dpark task sets onto a Mesos cluster.

    The ``MesosSchedulerDriver`` is created lazily by the first
    ``submitTasks`` call (see ``start_driver``). A watchdog thread started
    there stops the scheduler after ``MAX_IDLE_TIME`` seconds without active
    task sets. Tasks are launched with the executor described by
    ``getExecutorInfo`` (``executor.py`` next to this module).
    """

    def __init__(self, master, options, webui_url=None):
        DAGScheduler.__init__(self)
        self.master = master
        # Default per-task resources, used when an RDD does not set its own.
        self.cpus = options.cpus
        self.mem = options.mem
        # Maximum number of tasks dispatched to a single agent at once.
        self.task_per_node = options.parallel or 8
        self.group = options.group
        self.logLevel = options.logLevel
        self.options = options
        self.role = options.role
        self.color = options.color
        self.webui_url = webui_url
        self.started = False
        self.last_finish_time = 0
        self.isRegistered = False
        self.executor = None
        self.driver = None
        # Their addresses are shipped to executors in getExecutorInfo().
        self.out_logger = LogReceiver(sys.stdout)
        self.err_logger = LogReceiver(sys.stderr)
        self.lock = threading.RLock()
        self.task_host_manager = TaskHostManager()
        self.init_tasksets()

    def init_tasksets(self):
        """Reset the bookkeeping of active task sets and task placement."""
        # taskset id -> TaskSet still being scheduled.
        self.active_tasksets = {}
        # Mesos task id -> id of the agent the task was launched on.
        self.ttid_to_agent_id = {}
        # agent id -> number of tasks currently launched on that agent.
        self.agent_id_to_ttids = {}

    def clear(self):
        """Clear the base DAG scheduler state and the task-set bookkeeping."""
        DAGScheduler.clear(self)
        self.init_tasksets()

    def processHeartBeat(self):
        # no need in dpark now, just for compatibility with pymesos
        pass

    def start(self):
        """Start the receivers for executor stdout/stderr logs."""
        self.out_logger.start()
        self.err_logger.start()

    def start_driver(self):
        """Create and start the Mesos driver plus the idle/timeout watchdog.

        Raises:
            Exception: when running as ``root``, which dpark refuses.
        """
        name = '[dpark] ' + \
               os.path.abspath(sys.argv[0]) + ' ' + ' '.join(sys.argv[1:])
        if len(name) > 256:
            name = name[:256] + '...'
        framework = Dict()
        framework.user = getuser()
        if framework.user == 'root':
            raise Exception('dpark is not allowed to run as \'root\'')
        framework.name = name
        if self.role:
            framework.role = self.role
        framework.hostname = socket.gethostname()
        if self.webui_url:
            framework.webui_url = self.webui_url

        self.driver = MesosSchedulerDriver(self,
                                           framework,
                                           self.master,
                                           use_addict=True)
        self.driver.start()
        logger.debug('Mesos Scheudler driver started')

        self.started = True
        self.last_finish_time = time.time()

        def check():
            # Once a second: stop after MAX_IDLE_TIME seconds without active
            # task sets, and revive offers whenever a task set's
            # check_task_timeout() returns true.
            while self.started:
                with self.lock:
                    now = time.time()
                    if (not self.active_tasksets
                            and now - self.last_finish_time > MAX_IDLE_TIME):
                        logger.info(
                            'stop mesos scheduler after %d seconds idle',
                            now - self.last_finish_time)
                        self.stop()
                        break

                    for taskset in self.active_tasksets.values():
                        if taskset.check_task_timeout():
                            self.requestMoreResources()
                time.sleep(1)

        spawn(check)

    @safe
    def registered(self, driver, frameworkId, masterInfo):
        """Record the framework id and build the executor description."""
        self.isRegistered = True
        self.frameworkId = frameworkId.value
        logger.debug('connect to master %s:%s, registered as %s',
                     masterInfo.hostname, masterInfo.port, frameworkId.value)
        self.executor = self.getExecutorInfo(str(frameworkId.value))
        from dpark.utils.log import add_loghub
        self.loghub_dir = add_loghub(self.frameworkId)

    @safe
    def reregistered(self, driver, masterInfo):
        logger.warning('re-connect to mesos master %s:%s', masterInfo.hostname,
                       masterInfo.port)

    @safe
    def disconnected(self, driver):
        logger.debug('framework is disconnected')

    def _get_container_image(self):
        """Return the Docker image to run executors in (falsy for none)."""
        return self.options.image

    @safe
    def getExecutorInfo(self, framework_id):
        """Build the ExecutorInfo used for every task of this framework.

        The executor runs ``executor.py`` with the current interpreter,
        optionally inside a Docker container (``options.image``) with the
        configured volumes mounted. Its ``data`` carries the driver-side
        context: script, cwd, ``sys.path``, environment, per-node task
        limit, log receiver addresses, log level, color flag and
        ``env.environ``.

        Raises:
            Exception: if an entry of ``options.volumes`` has more than three
                ``:``-separated fields.
        """
        info = Dict()
        info.framework_id.value = framework_id
        info.command.value = '%s %s' % (
            sys.executable,
            os.path.abspath(
                os.path.join(os.path.dirname(__file__), 'executor.py')))
        info.executor_id.value = env.get('DPARK_ID', 'default')
        info.command.environment.variables = variables = []

        v = Dict()
        variables.append(v)
        v.name = 'UID'
        v.value = str(os.getuid())

        v = Dict()
        variables.append(v)
        v.name = 'GID'
        v.value = str(os.getgid())

        container_image = self._get_container_image()
        if container_image:
            info.container.type = 'DOCKER'
            info.container.docker.image = container_image
            info.container.docker.parameters = parameters = []
            p = Dict()
            p.key = 'memory-swap'
            p.value = '-1'
            parameters.append(p)

            info.container.volumes = volumes = []
            for path in ['/etc/passwd', '/etc/group']:
                v = Dict()
                volumes.append(v)
                v.host_path = v.container_path = path
                v.mode = 'RO'

            for path in conf.MOOSEFS_MOUNT_POINTS:
                v = Dict()
                volumes.append(v)
                v.host_path = v.container_path = path
                v.mode = 'RW'

            for path in conf.DPARK_WORK_DIR.split(','):
                v = Dict()
                volumes.append(v)
                v.host_path = v.container_path = path
                v.mode = 'RW'

            def _mount_volume(_volumes, _host_path, _container_path, _mode):
                _v = Dict()
                _volumes.append(_v)
                _v.container_path = _container_path
                _v.mode = _mode
                if _host_path:
                    _v.host_path = _host_path

            if self.options.volumes:
                # Comma-separated entries of the form
                # "host:container:mode", "host:container" or "container".
                for volume in self.options.volumes.split(','):
                    fields = volume.split(':')
                    if len(fields) == 3:
                        host_path, container_path, mode = fields
                        mode = mode.upper()
                        assert mode in ('RO', 'RW')
                    elif len(fields) == 2:
                        host_path, container_path = fields
                        mode = 'RW'
                    elif len(fields) == 1:
                        container_path, = fields
                        host_path = ''
                        mode = 'RW'
                    else:
                        # Format the message; passing ``volume`` as a second
                        # argument left "%s" unformatted in the error.
                        raise Exception('cannot parse volume %s' % volume)
                    _mount_volume(volumes, host_path, container_path, mode)

        info.resources = resources = []

        mem = Dict()
        resources.append(mem)
        mem.name = 'mem'
        mem.type = 'SCALAR'
        mem.scalar.value = EXECUTOR_MEMORY

        cpus = Dict()
        resources.append(cpus)
        cpus.name = 'cpus'
        cpus.type = 'SCALAR'
        cpus.scalar.value = EXECUTOR_CPUS

        Script = os.path.realpath(sys.argv[0])
        info.name = Script

        info.data = encode_data(
            marshal.dumps((Script, os.getcwd(), sys.path, dict(os.environ),
                           self.task_per_node, self.out_logger.addr,
                           self.err_logger.addr, self.logLevel, self.color,
                           env.environ)))
        assert len(info.data) < (50 << 20), \
            'Info data too large: %s' % (len(info.data),)
        return info

    @safe
    def submitTasks(self, tasks):
        """Register ``tasks`` (all from the same RDD) as a new task set.

        The driver is started on first use, and this call waits until the
        framework is registered; on later calls offers are revived instead.
        """
        if not tasks:
            return

        rdd = tasks[0].rdd
        assert all(t.rdd is rdd for t in tasks)

        taskset = TaskSet(self, tasks, rdd.cpus or self.cpus, rdd.mem
                          or self.mem, rdd.gpus, self.task_host_manager)
        self.active_tasksets[taskset.id] = taskset
        stage_scope = ''
        try:
            # Best effort: the scope only decorates the log line below.
            from dpark.web.ui.views.rddopgraph import StageInfo
            stage_scope = StageInfo.idToRDDNode[
                tasks[0].rdd.id].scope.call_site
        except Exception:
            pass
        stage = self.idToStage[tasks[0].stage_id]
        stage.num_try += 1
        logger.info(
            'Got taskset %s with %d tasks for stage: %d '
            'at scope[%s] and rdd:%s', taskset.id, len(tasks),
            tasks[0].stage_id, stage_scope, tasks[0].rdd)

        need_revive = self.started
        if not self.started:
            self.start_driver()
        while not self.isRegistered:
            # NOTE(review): assumes self.lock is held here (presumably taken
            # by @safe); release it so the registration callback can run.
            self.lock.release()
            time.sleep(0.01)
            self.lock.acquire()

        if need_revive:
            self.requestMoreResources()

    def requestMoreResources(self):
        """Ask Mesos to resend offers (undoes any earlier suppression)."""
        logger.debug('reviveOffers')
        self.driver.reviveOffers()

    @safe
    def resourceOffers(self, driver, offers):
        """Match active task sets against ``offers`` and answer every offer.

        With no active task sets, offers are suppressed and each offer is
        declined for 5 minutes. Offers from banned hosts are declined; offers
        outside the configured group are declined with a ``0xFFFFFFFF``
        second filter; offers from unhealthy hosts are declined for 30
        minutes. Remaining offers get the tasks assigned by
        ``TaskSet.taskOffer`` or are declined.
        """
        rf = Dict()
        if not self.active_tasksets:
            driver.suppressOffers()
            rf.refuse_seconds = 60 * 5
            for o in offers:
                driver.declineOffer(o.id, rf)
            return

        start = time.time()
        filter_offer = []
        for o in offers:
            try:
                if conf.ban(o.hostname):
                    logger.debug("skip offer on banned node: %s", o.hostname)
                    # Answer the offer instead of silently holding on to it.
                    driver.declineOffer(o.id)
                    continue
            except Exception:
                logger.exception("bad ban() func in dpark.conf")

            group = (self.getAttribute(o.attributes, 'group') or 'None')
            if (self.group
                    or group.startswith('_')) and group not in self.group:
                driver.declineOffer(o.id,
                                    filters=Dict(refuse_seconds=0xFFFFFFFF))
                continue
            if self.task_host_manager.is_unhealthy_host(o.hostname):
                logger.warning('the host %s is unhealthy so skip it',
                               o.hostname)
                driver.declineOffer(o.id, filters=Dict(refuse_seconds=1800))
                continue
            self.task_host_manager.register_host(o.hostname)
            filter_offer.append(o)
        offers = filter_offer
        cpus = [self.getResource(o.resources, 'cpus') for o in offers]
        gpus = [self.getResource(o.resources, 'gpus') for o in offers]
        # Reserve EXECUTOR_MEMORY on agents that have no task of ours yet.
        mems = [
            self.getResource(o.resources, 'mem') -
            (EXECUTOR_MEMORY
             if o.agent_id.value not in self.agent_id_to_ttids else 0)
            for o in offers
        ]

        tasks = {}
        for taskset in self.active_tasksets.values():
            # Keep offering to this task set until it accepts nothing more.
            while True:
                host_offers = {}
                for i, o in enumerate(offers):
                    if self.agent_id_to_ttids.get(o.agent_id.value,
                                                  0) >= self.task_per_node:
                        logger.debug('the task limit exceeded at host %s',
                                     o.hostname)
                        continue
                    if (mems[i] < self.mem + EXECUTOR_MEMORY
                            or cpus[i] < self.cpus + EXECUTOR_CPUS):
                        continue
                    host_offers[o.hostname] = (i, o)
                assigned_list = taskset.taskOffer(host_offers, cpus, mems,
                                                  gpus)
                if not assigned_list:
                    break
                for i, o, t in assigned_list:
                    task = self.createTask(o, t)
                    tasks.setdefault(o.id.value, []).append(task)
                    logger.debug('dispatch %s into %s', t, o.hostname)
                    ttid = task.task_id.value
                    agent_id = o.agent_id.value
                    taskset.ttids.add(ttid)
                    self.ttid_to_agent_id[ttid] = agent_id
                    self.agent_id_to_ttids[
                        agent_id] = self.agent_id_to_ttids.get(agent_id, 0) + 1
                    cpus[i] -= min(cpus[i], t.cpus)
                    mems[i] -= t.mem
                    gpus[i] -= t.gpus

        used = time.time() - start
        if used > 10:
            logger.error('use too much time in resourceOffers: %.2fs', used)

        for o in offers:
            if o.id.value in tasks:
                driver.launchTasks(o.id, tasks[o.id.value])
            else:
                driver.declineOffer(o.id)

    @safe
    def offerRescinded(self, driver, offer_id):
        """Revive offers if any task set is still waiting for resources."""
        logger.debug('rescinded offer: %s', offer_id)
        if self.active_tasksets:
            self.requestMoreResources()

    def getResource(self, res, name):
        """Return the scalar value of resource ``name`` in ``res`` (or 0.0)."""
        for r in res:
            if r.name == name:
                return r.scalar.value
        return 0.0

    def getAttribute(self, attrs, name):
        """Return the text value of attribute ``name`` (None if absent)."""
        for r in attrs:
            if r.name == name:
                return r.text.value

    def createTask(self, o, t):
        """Build the Mesos TaskInfo running task ``t`` on offer ``o``.

        The task payload is ``(t, t.try_id)``: pickled, compressed and
        encoded. It asks for ``t.cpus`` cpus, ``t.mem`` mem and ``t.gpus``
        gpus.
        """
        task = Dict()
        tid = t.try_id
        task.name = 'task %s' % tid
        task.task_id.value = tid
        task.agent_id.value = o.agent_id.value
        task.data = encode_data(compress(cPickle.dumps((t, tid), -1)))
        task.executor = self.executor
        if len(task.data) > 1000 * 1024:
            logger.warning('task too large: %s %d', t, len(task.data))

        assert len(task.data) < (50 << 20), \
            'Task data too large: %s' % (len(task.data),)

        resources = task.resources = []

        cpu = Dict()
        resources.append(cpu)
        cpu.name = 'cpus'
        cpu.type = 'SCALAR'
        cpu.scalar.value = t.cpus

        mem = Dict()
        resources.append(mem)
        mem.name = 'mem'
        mem.type = 'SCALAR'
        mem.scalar.value = t.mem

        gpu = Dict()
        resources.append(gpu)
        gpu.name = 'gpus'
        gpu.type = 'SCALAR'
        gpu.scalar.value = t.gpus

        return task

    @staticmethod
    def _read_url(url):
        """Return the body fetched from ``url``, always closing the response."""
        resp = urllib.request.urlopen(url)
        try:
            return resp.read()
        finally:
            resp.close()

    @safe
    def statusUpdate(self, driver, status):
        """Handle a Mesos status update for one dpark task attempt.

        Updates for task sets that are no longer active are ignored, except
        that still-running tasks are killed. RUNNING updates are forwarded
        to the task set. For any other state the task's placement
        bookkeeping is released and the task set is told the outcome.
        FINISHED/FAILED updates that carry ``data`` are decoded first into
        ``(reason, result, accUpdate, task_stats)``.
        """
        def plot_progresses():
            if self.color:
                total = len(self.active_tasksets)
                logger.info('\x1b[2K\x1b[J\x1b[1A')
                for i, taskset_id in enumerate(self.active_tasksets):
                    if i == total - 1:
                        ending = '\x1b[%sA' % total
                    else:
                        ending = ''

                    tasksets = self.active_tasksets[taskset_id]
                    tasksets.progress(ending)

        mesos_task_id = status.task_id.value
        state = status.state
        reason = status.get('message')  # set by mesos
        data = status.get('data')

        logger.debug('status update: %s %s', mesos_task_id, state)

        ttid = TTID(mesos_task_id)

        taskset = self.active_tasksets.get(ttid.taskset_id)

        if taskset is None:
            if state == 'TASK_RUNNING':
                logger.debug('kill task %s as its taskset has gone',
                             mesos_task_id)
                self.driver.killTask(Dict(value=mesos_task_id))
            else:
                logger.debug('ignore task %s as its taskset has gone',
                             mesos_task_id)
            return

        if state == 'TASK_RUNNING':
            taskset.statusUpdate(ttid.task_id, ttid.task_try, state)
            if taskset.tasksFinished == 0:
                plot_progresses()
            return

        if mesos_task_id not in taskset.ttids:
            logger.debug(
                'ignore task %s as it has finished or failed, new msg: %s',
                mesos_task_id, (state, reason))
            return

        taskset.ttids.remove(mesos_task_id)
        if mesos_task_id in self.ttid_to_agent_id:
            agent_id = self.ttid_to_agent_id[mesos_task_id]
            if agent_id in self.agent_id_to_ttids:
                self.agent_id_to_ttids[agent_id] -= 1
            del self.ttid_to_agent_id[mesos_task_id]

        if state in ('TASK_FINISHED', 'TASK_FAILED') and data:
            try:
                # NOTE: unpickling trusts the executor-supplied payload.
                reason, result, accUpdate, task_stats = cPickle.loads(
                    decode_data(data))
                if result:
                    flag, data = result
                    if flag >= 2:
                        # flag >= 2: ``data`` is a URL to fetch the payload.
                        try:
                            data = self._read_url(data)
                        except IOError:
                            # try again
                            data = self._read_url(data)
                        flag -= 2
                    data = decompress(data)
                    # flag 0: marshal-encoded, otherwise pickled.
                    if flag == 0:
                        result = marshal.loads(data)
                    else:
                        result = cPickle.loads(data)
                taskset.statusUpdate(ttid.task_id, ttid.task_try,
                                     state, reason, result, accUpdate,
                                     task_stats)
                if state == 'TASK_FINISHED':
                    plot_progresses()
            except Exception as e:
                logger.warning(
                    'error when cPickle.loads(): %s, data:%s', e,
                    len(data))
                state = 'TASK_FAILED'
                taskset.statusUpdate(ttid.task_id, ttid.task_try,
                                     state, 'load failed: %s' % e)
        else:
            # killed, lost
            taskset.statusUpdate(ttid.task_id, ttid.task_try, state,
                                 reason or data)

    @safe
    def tasksetFinished(self, taskset):
        """Retire ``taskset``: kill its leftover tasks and drop it."""
        logger.debug('taskset %s finished', taskset.id)
        if taskset.id in self.active_tasksets:
            self.last_finish_time = time.time()
            for mesos_task_id in taskset.ttids:
                self.driver.killTask(Dict(value=mesos_task_id))
            del self.active_tasksets[taskset.id]

            if not self.active_tasksets:
                self.agent_id_to_ttids.clear()

    @safe
    def error(self, driver, message):
        """Log a Mesos error and raise it as RuntimeError."""
        logger.error('Mesos error message: %s', message)
        raise RuntimeError(message)

    # @safe
    def stop(self):
        """Stop the driver (waiting for it) and the log receivers, once."""
        if not self.started:
            return
        logger.debug('stop scheduler')
        self.started = False
        self.isRegistered = False
        self.driver.stop(False)
        self.driver.join()
        self.driver = None

        self.out_logger.stop()
        self.err_logger.stop()

    def defaultParallelism(self):
        return 16

    def frameworkMessage(self, driver, executor_id, agent_id, data):
        logger.warning('[agent %s] %s', agent_id.value, data)

    def executorLost(self, driver, executor_id, agent_id, status):
        """Log the loss and forget the agent's task count."""
        logger.warning('executor at %s %s lost: %s', agent_id.value,
                       executor_id.value, status)
        self.agent_id_to_ttids.pop(agent_id.value, None)

    def slaveLost(self, driver, agent_id):
        """Log the loss and forget the agent's task count."""
        logger.warning('agent %s lost', agent_id.value)
        self.agent_id_to_ttids.pop(agent_id.value, None)

    def killTask(self, task_id, num_try):
        """Kill attempt ``num_try`` of task ``task_id``."""
        tid = Dict()
        tid.value = TTID.make_ttid(task_id, num_try)
        self.driver.killTask(tid)
コード例 #4
0
ファイル: scheduler.py プロジェクト: vshlapakov/pymesos
class ProcScheduler(Scheduler):
    """Mesos scheduler that launches submitted procs as Mesos tasks.

    A proc waits in ``procs_pending`` until an offer has enough cpus and
    mem, then moves to ``procs_launched``. ``slave_to_proc`` maps a slave id
    string to the set of proc ids reported running there. Bookkeeping is
    done under ``self._lock``.
    """

    def __init__(self):
        self.framework_id = None
        self.framework = self._init_framework()
        self.executor = None
        # Read MESOS_MASTER only when CONFIG has no master, so a configured
        # master works even without the environment variable.
        master = CONFIG.get("master")
        if master is None:
            master = os.environ["MESOS_MASTER"]
        self.master = str(master)
        self.driver = MesosSchedulerDriver(self, self.framework, self.master)
        self.procs_pending = {}
        self.procs_launched = {}
        self.slave_to_proc = {}
        self._lock = RLock()

    def _init_framework(self):
        """Build the FrameworkInfo (current user, repr(self), hostname)."""
        framework = mesos_pb2.FrameworkInfo()
        framework.user = getpass.getuser()
        framework.name = repr(self)
        framework.hostname = socket.gethostname()
        return framework

    def _init_executor(self):
        """Build the ExecutorInfo running ``<package>.executor``.

        Requires ``self.framework_id`` to be set (done in ``registered``).
        """
        executor = mesos_pb2.ExecutorInfo()
        executor.executor_id.value = "default"
        executor.command.value = "%s -m %s.executor" % (sys.executable, __package__)

        mem = executor.resources.add()
        mem.name = "mem"
        mem.type = mesos_pb2.Value.SCALAR
        mem.scalar.value = MIN_MEMORY

        cpus = executor.resources.add()
        cpus.name = "cpus"
        cpus.type = mesos_pb2.Value.SCALAR
        cpus.scalar.value = MIN_CPUS

        if "PYTHONPATH" in os.environ:
            var = executor.command.environment.variables.add()
            var.name = "PYTHONPATH"
            var.value = os.environ["PYTHONPATH"]

        executor.framework_id.value = str(self.framework_id.value)
        return executor

    def _init_task(self, proc, offer):
        """Build the TaskInfo for ``proc`` on ``offer`` (params pickled)."""
        task = mesos_pb2.TaskInfo()
        task.task_id.value = str(proc.id)
        task.slave_id.value = offer.slave_id.value
        task.name = repr(proc)
        task.executor.MergeFrom(self.executor)
        task.data = pickle.dumps(proc.params)

        cpus = task.resources.add()
        cpus.name = "cpus"
        cpus.type = mesos_pb2.Value.SCALAR
        cpus.scalar.value = proc.cpus

        mem = task.resources.add()
        mem.name = "mem"
        mem.type = mesos_pb2.Value.SCALAR
        mem.scalar.value = proc.mem

        return task

    def _filters(self, seconds):
        """Return offer Filters refusing re-offers for ``seconds``."""
        f = mesos_pb2.Filters()
        f.refuse_seconds = seconds
        return f

    def __repr__(self):
        return "%s[%s]: %s" % (self.__class__.__name__, os.getpid(), " ".join(sys.argv))

    def registered(self, driver, framework_id, master_info):
        with self._lock:
            logger.info("Framework registered with id=%s, master=%s" % (framework_id, master_info))
            self.framework_id = framework_id
            self.executor = self._init_executor()

    def resourceOffers(self, driver, offers):
        """Place pending procs on ``offers`` (first fit, random offer order).

        With nothing pending, each offer is answered with no tasks and a
        FOREVER filter. Otherwise each offer is answered (possibly with no
        tasks) with a random 5-10 second filter.
        """
        def get_resources(offer):
            cpus, mem = 0.0, 0.0
            for r in offer.resources:
                if r.name == "cpus":
                    cpus = float(r.scalar.value)
                elif r.name == "mem":
                    mem = float(r.scalar.value)
            return cpus, mem

        with self._lock:
            random.shuffle(offers)
            for offer in offers:
                if not self.procs_pending:
                    logger.debug("Reject offers forever for no pending procs, " "offers=%s" % (offers,))
                    driver.launchTasks(offer.id, [], self._filters(FOREVER))
                    continue

                cpus, mem = get_resources(offer)
                tasks = []
                # Iterate over a snapshot: launched procs are removed from
                # procs_pending inside the loop.
                for proc in list(self.procs_pending.values()):
                    if cpus >= proc.cpus and mem >= proc.mem:
                        tasks.append(self._init_task(proc, offer))
                        del self.procs_pending[proc.id]
                        self.procs_launched[proc.id] = proc
                        cpus -= proc.cpus
                        mem -= proc.mem

                seconds = 5 + random.random() * 5
                driver.launchTasks(offer.id, tasks, self._filters(seconds))
                if tasks:
                    logger.info(
                        "Accept offer for procs, offer=%s, "
                        "procs=%s, filter_time=%s" % (offer, [int(t.task_id.value) for t in tasks], seconds)
                    )
                else:
                    logger.info("Retry offer for procs later, offer=%s, " "filter_time=%s" % (offer, seconds))

    def _call_finished(self, proc_id, success, message, data, slave_id=None):
        """Forget ``proc_id`` and report its outcome via ``proc._finished``.

        The id is dropped from ``slave_to_proc``: only under ``slave_id``
        when given, otherwise under every slave. The proc is looked up in
        ``procs_launched``, then in ``procs_pending``. An id found in
        neither (e.g. already cancelled) is ignored.
        """
        with self._lock:
            if slave_id is not None:
                procs = self.slave_to_proc.get(slave_id)
                if procs is not None:
                    procs.discard(proc_id)
            else:
                for procs in self.slave_to_proc.values():
                    procs.discard(proc_id)

            proc = self.procs_launched.pop(proc_id, None)
            if proc is None:
                proc = self.procs_pending.pop(proc_id, None)
            if proc is None:
                logger.debug("Ignore finish of unknown proc, id=%s", proc_id)
                return

            proc._finished(success, message, data)

    def statusUpdate(self, driver, update):
        """Track a proc's task state: start notifications and completion."""
        with self._lock:
            proc_id = int(update.task_id.value)
            logger.info("Status update for proc, id=%s, state=%s" % (proc_id, update.state))
            if update.state == mesos_pb2.TASK_RUNNING:
                slave_id = update.slave_id.value
                self.slave_to_proc.setdefault(slave_id, set()).add(proc_id)

                proc = self.procs_launched.get(proc_id)
                # The proc may have been cancelled in the meantime.
                if proc is not None:
                    proc._started()

            # NOTE(review): relies on the numeric order of mesos_pb2 task
            # states (every value >= TASK_FINISHED is treated as terminal);
            # confirm against the mesos.proto in use.
            elif update.state >= mesos_pb2.TASK_FINISHED:
                slave_id = update.slave_id.value
                success = update.state == mesos_pb2.TASK_FINISHED
                message = update.message
                # NOTE: unpickles executor-supplied bytes; trusted input only.
                data = update.data and pickle.loads(update.data)
                self._call_finished(proc_id, success, message, data, slave_id)
                driver.reviveOffers()

    def offerRescinded(self, driver, offer_id):
        with self._lock:
            if self.procs_pending:
                logger.info("Revive offers for pending procs")
                driver.reviveOffers()

    def slaveLost(self, driver, slave_id):
        """Fail every proc recorded as running on the lost slave."""
        # slave_to_proc is keyed by the slave id string (see statusUpdate);
        # accept either a SlaveID message or a plain string here.
        key = getattr(slave_id, "value", slave_id)
        with self._lock:
            for proc_id in list(self.slave_to_proc.pop(key, ())):
                self._call_finished(proc_id, False, "Slave lost", None, key)

    def error(self, driver, message):
        """Fail all pending and launched procs, then stop the driver."""
        with self._lock:
            # Snapshots: _call_finished removes entries from both dicts.
            for proc in list(self.procs_pending.values()):
                self._call_finished(proc.id, False, "Stopped", None)

            for proc in list(self.procs_launched.values()):
                self._call_finished(proc.id, False, "Stopped", None)

        self.stop()

    def start(self):
        self.driver.start()

    def stop(self):
        assert not self.driver.aborted
        self.driver.stop()

    def submit(self, proc):
        """Queue ``proc`` for launch; revive offers if it is the only one.

        Raises:
            RuntimeError: if the driver has been aborted.
            ValueError: if a proc with the same id is already pending.
        """
        if self.driver.aborted:
            raise RuntimeError("driver already aborted")

        with self._lock:
            if proc.id in self.procs_pending:
                raise ValueError("Proc with same id already submitted")

            logger.info("Try submit proc, id=%s", proc.id)
            self.procs_pending[proc.id] = proc
            if len(self.procs_pending) == 1:
                logger.info("Revive offers for pending procs")
                self.driver.reviveOffers()

    def cancel(self, proc):
        """Withdraw ``proc``; kill its task if it was already launched.

        Raises:
            RuntimeError: if the driver has been aborted.
        """
        if self.driver.aborted:
            raise RuntimeError("driver already aborted")

        with self._lock:
            if proc.id in self.procs_pending:
                del self.procs_pending[proc.id]
            elif proc.id in self.procs_launched:
                del self.procs_launched[proc.id]
                self.driver.killTask(mesos_pb2.TaskID(value=str(proc.id)))

            # Values are sets of proc ids (set.pop() takes no argument);
            # iterate over a snapshot because emptied entries are deleted.
            for slave_id, procs in list(self.slave_to_proc.items()):
                procs.discard(proc.id)
                if not procs:
                    del self.slave_to_proc[slave_id]

    def send_data(self, pid, type, data):
        """Send ``(pid, type, data)`` to the executor on the proc's slave.

        Raises:
            RuntimeError: if the driver is aborted or no slave runs ``pid``.
        """
        if self.driver.aborted:
            raise RuntimeError("driver already aborted")

        msg = pickle.dumps((pid, type, data))
        for slave_id, procs in self.slave_to_proc.items():
            if pid in procs:
                self.driver.sendFrameworkMessage(self.executor.executor_id, mesos_pb2.SlaveID(value=slave_id), msg)
                return

        raise RuntimeError("Cannot find slave for pid %s" % (pid,))
コード例 #5
0
ファイル: scheduler.py プロジェクト: pandasasa/tfmesos
class TFMesosScheduler(Scheduler):
    """Launch a TensorFlow cluster as one Mesos task per job replica.

    ``start()`` opens a listening socket and waits until every task has
    connected back with ``(mesos_task_id, addr)``; the cluster definition is
    then sent to every task (``_start_tf_cluster``).
    """

    def __init__(self,
                 task_spec,
                 master=None,
                 name=None,
                 quiet=False,
                 volumes=None):
        # ``volumes`` defaults to None rather than {} so that separate
        # instances never share one mutable default dict.
        if volumes is None:
            volumes = {}
        # NOTE(review): nothing in this class sets ``started`` to True, so
        # the ``if self.started`` branches below only run if a caller sets
        # it; confirm whether start() should.
        self.started = False
        self.master = master or os.environ['MESOS_MASTER']
        self.name = name or '[tensorflow] %s %s' % (os.path.abspath(
            sys.argv[0]), ' '.join(sys.argv[1:]))
        self.task_spec = task_spec
        self.tasks = []
        for job in task_spec:
            for task_index in range(job.num):
                # The task's position in self.tasks doubles as its Mesos id.
                mesos_task_id = len(self.tasks)
                self.tasks.append(
                    Task(
                        mesos_task_id,
                        job.name,
                        task_index,
                        cpus=job.cpus,
                        mem=job.mem,
                        volumes=volumes,
                    ))
        if not quiet:
            global logger
            setup_logger(logger)

    def resourceOffers(self, driver, offers):
        '''
        Offer resources and launch tasks
        '''

        for offer in offers:
            if all(task.offered for task in self.tasks):
                driver.declineOffer(offer.id,
                                    mesos_pb2.Filters(refuse_seconds=FOREVER))
                continue

            offered_cpus = offered_mem = 0.0
            offered_tasks = []

            for resource in offer.resources:
                if resource.name == "cpus":
                    offered_cpus = resource.scalar.value
                elif resource.name == "mem":
                    offered_mem = resource.scalar.value

            # First fit: pack every not-yet-offered task that still fits.
            for task in self.tasks:
                if task.offered:
                    continue

                if not (task.cpus <= offered_cpus and task.mem <= offered_mem):

                    continue

                offered_cpus -= task.cpus
                offered_mem -= task.mem
                task.offered = True
                offered_tasks.append(task.to_task_info(offer, self.addr))

            driver.launchTasks(offer.id, offered_tasks, mesos_pb2.Filters())

    def _start_tf_cluster(self):
        """Send the cluster definition to every task and close connections.

        Returns:
            Mapping of ``/job:<name>/task:<index>`` to ``grpc://<addr>``.
        """
        cluster_def = {}

        targets = {}
        for task in self.tasks:
            target_name = '/job:%s/task:%s' % (task.job_name, task.task_index)
            grpc_addr = 'grpc://%s' % task.addr
            targets[target_name] = grpc_addr
            cluster_def.setdefault(task.job_name, []).append(task.addr)

        for task in self.tasks:
            response = {
                "job_name": task.job_name,
                "task_index": task.task_index,
                "cpus": task.cpus,
                "mem": task.mem,
                "cluster_def": cluster_def,
            }
            send(task.connection, response)
            assert recv(task.connection) == "ok"
            logger.info("Device /job:%s/task:%s activated @ grpc://%s " %
                        (task.job_name, task.task_index, task.addr))
            task.connection.close()
        return targets

    def start(self):
        """Start the driver and block until every task has connected back.

        Returns:
            The targets mapping from ``_start_tf_cluster``.

        Raises:
            Any exception hit during start-up, after calling ``stop()``.
        """
        def readable(fd):
            return bool(select.select([fd], [], [], 0.1)[0])

        lfd = socket.socket()
        try:
            lfd.bind(('', 0))
            self.addr = '%s:%s' % (socket.gethostname(), lfd.getsockname()[1])
            lfd.listen(10)
            framework = mesos_pb2.FrameworkInfo()
            framework.user = getpass.getuser()
            framework.name = self.name
            framework.hostname = socket.gethostname()
            self.driver = MesosSchedulerDriver(self, framework, self.master)
            self.driver.start()
            while any((not task.initalized for task in self.tasks)):
                if readable(lfd):
                    c, _ = lfd.accept()
                    if readable(c):
                        mesos_task_id, addr = recv(c)
                        assert isinstance(mesos_task_id, int)
                        task = self.tasks[mesos_task_id]
                        task.addr = addr
                        task.connection = c
                        task.initalized = True
                    else:
                        c.close()
            return self._start_tf_cluster()
        except Exception:
            self.stop()
            raise
        finally:
            lfd.close()

    def registered(self, driver, framework_id, master_info):
        logger.info(
            "Tensorflow cluster registered. "
            "( http://%s:%s/#/frameworks/%s )" %
            (master_info.hostname, master_info.port, framework_id.value))

    def statusUpdate(self, driver, update):
        """React to a task leaving the staging/starting/running states.

        Once started, any such update is fatal. Before that, the task is
        released and offers are revived so it can be launched again.
        """
        # Not-yet-running states are progress, not failure.
        if update.state in (mesos_pb2.TASK_STAGING, mesos_pb2.TASK_STARTING,
                            mesos_pb2.TASK_RUNNING):
            return

        mesos_task_id = int(update.task_id.value)
        task = self.tasks[mesos_task_id]
        if self.started:
            logger.error("Task failed: %s" % task)
            # Report the task id (previously formatted the builtin ``id``).
            _raise(RuntimeError('Task %s failed!' % mesos_task_id))
        else:
            logger.warning("Task failed: %s" % task)
            # The task may fail before it ever connected back.
            connection = getattr(task, 'connection', None)
            if connection is not None:
                connection.close()
            # resourceOffers skips tasks still marked as offered, so clear
            # the flag or the revived offers could never relaunch it.
            task.offered = False
            driver.reviveOffers()

    def slaveLost(self, driver, slaveId):
        if self.started:
            logger.error("Slave %s lost:" % slaveId.value)
            _raise(RuntimeError('Slave %s lost' % slaveId))

    def executorLost(self, driver, executorId, slaveId, status):
        if self.started:
            logger.error("Executor %s lost:" % executorId.value)
            _raise(RuntimeError('Executor %s@%s lost' % (executorId, slaveId)))

    def error(self, driver, message):
        logger.error("Mesos error: %s" % message)
        _raise(RuntimeError('Error ' + message))

    def stop(self):
        """Close task connections and stop the driver; safe to call twice."""
        logger.debug("exit")

        if hasattr(self, "tasks"):
            for task in self.tasks:
                # Tasks that never connected back have no connection.
                connection = getattr(task, "connection", None)
                if connection is not None:
                    connection.close()
            del self.tasks

        if hasattr(self, "driver"):
            self.driver.stop()
            del self.driver
コード例 #6
0
ファイル: schedule.py プロジェクト: windreamer/dpark
class MesosScheduler(DAGScheduler):
    """Run dpark jobs as Mesos tasks.

    The first ``submitTasks`` call starts the Mesos driver.  A background
    watcher stops it again after ``MAX_IDLE_TIME`` seconds with no active
    job.  Offers are packed greedily, job by job, in submission order.
    """

    def __init__(self, master, options):
        DAGScheduler.__init__(self)
        self.master = master
        self.use_self_as_exec = options.self
        self.cpus = options.cpus
        self.mem = options.mem
        self.task_per_node = options.parallel or multiprocessing.cpu_count()
        self.group = options.group
        self.logLevel = options.logLevel
        self.options = options
        self.started = False
        self.last_finish_time = 0
        self.isRegistered = False
        self.executor = None
        self.driver = None
        # Addresses of the stdout/stderr log collectors, set by start().
        self.out_logger = None
        self.err_logger = None
        self.lock = threading.RLock()
        self.init_job()

    def init_job(self):
        """Reset all per-job and per-task bookkeeping."""
        self.activeJobs = {}        # job id -> job that has not finished
        self.activeJobsQueue = []   # active jobs, in submission order
        self.taskIdToJobId = {}     # mesos task id -> job id
        self.taskIdToAgentId = {}   # mesos task id -> agent id
        self.jobTasks = {}          # job id -> set of launched mesos task ids
        self.agentTasks = {}        # agent id -> number of tasks launched there

    def clear(self):
        """Clear the base scheduler state and all job bookkeeping."""
        DAGScheduler.clear(self)
        self.init_job()

    def start(self):
        """Start the stdout/stderr log collectors unless already running."""
        if not self.out_logger:
            self.out_logger = self.start_logger(sys.stdout)
        if not self.err_logger:
            self.err_logger = self.start_logger(sys.stderr)

    def start_driver(self):
        """Register the framework with Mesos and spawn the idle watcher.

        Raises:
            Exception: if the current user is ``root``.
        """
        name = '[dpark] ' + \
            os.path.abspath(sys.argv[0]) + ' ' + ' '.join(sys.argv[1:])
        if len(name) > 256:
            name = name[:256] + '...'
        framework = Dict()
        framework.user = getuser()
        if framework.user == 'root':
            raise Exception('dpark is not allowed to run as \'root\'')
        framework.name = name
        framework.hostname = socket.gethostname()
        framework.webui_url = self.options.webui_url

        self.driver = MesosSchedulerDriver(
            self, framework, self.master, use_addict=True
        )
        self.driver.start()
        logger.debug('Mesos Scheudler driver started')

        self.started = True
        self.last_finish_time = time.time()

        def check():
            # Poll once a second.  Once no job has been active for
            # MAX_IDLE_TIME seconds, stop the scheduler.
            while self.started:
                now = time.time()
                if (not self.activeJobs and
                        now - self.last_finish_time > MAX_IDLE_TIME):
                    logger.info('stop mesos scheduler after %d seconds idle',
                                now - self.last_finish_time)
                    self.stop()
                    break
                time.sleep(1)

        spawn(check)

    def start_logger(self, output):
        """Forward remote log lines to ``output``; return the collector URL.

        Binds a ZeroMQ PULL socket on a random port and spawns a loop that
        writes every received line to ``output``.
        """
        sock = env.ctx.socket(zmq.PULL)
        port = sock.bind_to_random_port('tcp://0.0.0.0')

        def collect_log():
            # NOTE(review): ``_shutdown`` is not set in this class; presumably
            # it is provided by DAGScheduler -- confirm.
            while not self._shutdown:
                if sock.poll(1000, zmq.POLLIN):
                    line = sock.recv()
                    output.write(line)

        spawn(collect_log)

        host = socket.gethostname()
        addr = 'tcp://%s:%d' % (host, port)
        logger.debug('log collecter start at %s', addr)
        return addr

    @safe
    def registered(self, driver, frameworkId, masterInfo):
        """Mark the scheduler registered and build the executor description."""
        self.isRegistered = True
        logger.debug('connect to master %s:%s, registered as %s',
                     masterInfo.hostname, masterInfo.port, frameworkId.value)
        self.executor = self.getExecutorInfo(str(frameworkId.value))

    @safe
    def reregistered(self, driver, masterInfo):
        """Log a reconnection to a Mesos master."""
        logger.warning('re-connect to mesos master %s:%s',
                       masterInfo.hostname, masterInfo.port)

    @safe
    def disconnected(self, driver):
        """Log the loss of the connection to the Mesos master."""
        logger.debug('framework is disconnected')

    @safe
    def getExecutorInfo(self, framework_id):
        """Build the ExecutorInfo message used for every launched task.

        The executor command is either this script itself
        (``options.self``) or ``executor.py`` next to this module.  When
        ``options.image`` is set, the executor runs in a Docker container
        with passwd/group, MooseFS, the dpark work dirs and every entry of
        ``options.volumes`` (``[host_path:]container_path[:mode]``,
        comma-separated) mounted.

        Raises:
            Exception: if a volume spec cannot be parsed or has a mode other
                than RO/RW.
        """
        info = Dict()
        info.framework_id.value = framework_id

        if self.use_self_as_exec:
            info.command.value = os.path.abspath(sys.argv[0])
            info.executor_id.value = sys.argv[0]
        else:
            info.command.value = '%s %s' % (
                sys.executable,
                os.path.abspath(
                    os.path.join(
                        os.path.dirname(__file__),
                        'executor.py'))
            )
            info.executor_id.value = 'default'

        info.command.environment.variables = variables = []

        v = Dict()
        variables.append(v)
        v.name = 'UID'
        v.value = str(os.getuid())

        v = Dict()
        variables.append(v)
        v.name = 'GID'
        v.value = str(os.getgid())

        if self.options.image:
            info.container.type = 'DOCKER'
            info.container.docker.image = self.options.image

            info.container.volumes = volumes = []
            for path in ['/etc/passwd', '/etc/group']:
                v = Dict()
                volumes.append(v)
                v.host_path = v.container_path = path
                v.mode = 'RO'

            for path in conf.MOOSEFS_MOUNT_POINTS:
                v = Dict()
                volumes.append(v)
                v.host_path = v.container_path = path
                v.mode = 'RW'

            for path in conf.DPARK_WORK_DIR.split(','):
                v = Dict()
                volumes.append(v)
                v.host_path = v.container_path = path
                v.mode = 'RW'

            if self.options.volumes:
                for volume in self.options.volumes.split(','):
                    fields = volume.split(':')
                    if len(fields) == 3:
                        host_path, container_path, mode = fields
                        mode = mode.upper()
                        # Explicit check: an assert would vanish under -O.
                        if mode not in ('RO', 'RW'):
                            raise Exception(
                                'invalid volume mode in %s' % volume)
                    elif len(fields) == 2:
                        host_path, container_path = fields
                        mode = 'RW'
                    elif len(fields) == 1:
                        container_path, = fields
                        host_path = ''
                        mode = 'RW'
                    else:
                        raise Exception('cannot parse volume %s' % volume)

                    # The one-field form has no host path, so there is no
                    # local directory to create.
                    if host_path:
                        mkdir_p(host_path)

                    # One volume entry per spec.  This used to sit after the
                    # loop, so only the last spec was ever mounted.
                    v = Dict()
                    volumes.append(v)
                    v.container_path = container_path
                    v.mode = mode
                    if host_path:
                        v.host_path = host_path

        info.resources = resources = []

        mem = Dict()
        resources.append(mem)
        mem.name = 'mem'
        mem.type = 'SCALAR'
        mem.scalar.value = EXECUTOR_MEMORY

        cpus = Dict()
        resources.append(cpus)
        cpus.name = 'cpus'
        cpus.type = 'SCALAR'
        cpus.scalar.value = EXECUTOR_CPUS

        Script = os.path.realpath(sys.argv[0])
        info.name = Script

        info.data = encode_data(marshal.dumps(
            (
                Script, os.getcwd(), sys.path, dict(os.environ),
                self.task_per_node, self.out_logger, self.err_logger,
                self.logLevel, env.environ
            )
        ))
        return info

    @safe
    def submitTasks(self, tasks):
        """Queue ``tasks`` as a new job, starting the driver if needed.

        Blocks until the framework is registered with Mesos.  Offers are
        revived only if the driver was already running.
        """
        if not tasks:
            return

        job = SimpleJob(self, tasks, self.cpus, tasks[0].rdd.mem or self.mem)
        self.activeJobs[job.id] = job
        self.activeJobsQueue.append(job)
        self.jobTasks[job.id] = set()
        logger.info(
            'Got job %d with %d tasks: %s',
            job.id,
            len(tasks),
            tasks[0].rdd)

        need_revive = self.started
        if not self.started:
            self.start_driver()
        # NOTE(review): assumes @safe holds self.lock here; it is released
        # while waiting so the `registered` callback can acquire it.
        while not self.isRegistered:
            self.lock.release()
            time.sleep(0.01)
            self.lock.acquire()

        if need_revive:
            self.requestMoreResources()

    def requestMoreResources(self):
        """Ask Mesos to resend resource offers."""
        logger.debug('reviveOffers')
        self.driver.reviveOffers()

    @safe
    def resourceOffers(self, driver, offers):
        """Pack pending tasks of active jobs into the given offers.

        With no active job, every offer is declined for five minutes.
        Otherwise offers are shuffled and filled greedily, job by job.
        An offer is skipped for a task when:
        - its ``group`` attribute fails the group filter (a configured
          ``self.group``, or an agent group starting with ``_``);
        - the agent already runs ``task_per_node`` tasks;
        - it lacks the per-task cpus/mem plus the executor overhead.
        Offers that got no task are declined.
        """
        rf = Dict()
        if not self.activeJobs:
            rf.refuse_seconds = 60 * 5
            for o in offers:
                driver.declineOffer(o.id, rf)
            return

        start = time.time()
        random.shuffle(offers)
        cpus = [self.getResource(o.resources, 'cpus') for o in offers]
        # Reserve executor memory on agents that do not run an executor yet.
        mems = [self.getResource(o.resources, 'mem')
                - (EXECUTOR_MEMORY
                   if o.agent_id.value not in self.agentTasks else 0)
                for o in offers]
        logger.debug('get %d offers (%s cpus, %s mem), %d jobs',
                     len(offers), sum(cpus), sum(mems), len(self.activeJobs))

        tasks = {}
        for job in self.activeJobsQueue:
            # Sweep all offers repeatedly until a sweep launches nothing.
            while True:
                launchedTask = False
                for i, o in enumerate(offers):
                    sid = o.agent_id.value
                    group = (
                        self.getAttribute(
                            o.attributes,
                            'group') or 'None')
                    if (self.group or group.startswith(
                            '_')) and group not in self.group:
                        continue
                    if self.agentTasks.get(sid, 0) >= self.task_per_node:
                        continue
                    if (mems[i] < self.mem + EXECUTOR_MEMORY
                            or cpus[i] < self.cpus + EXECUTOR_CPUS):
                        continue
                    t = job.slaveOffer(str(o.hostname), cpus[i], mems[i])
                    if not t:
                        continue
                    task = self.createTask(o, job, t)
                    tasks.setdefault(o.id.value, []).append(task)

                    logger.debug('dispatch %s into %s', t, o.hostname)
                    tid = task.task_id.value
                    self.jobTasks[job.id].add(tid)
                    self.taskIdToJobId[tid] = job.id
                    self.taskIdToAgentId[tid] = sid
                    self.agentTasks[sid] = self.agentTasks.get(sid, 0) + 1
                    cpus[i] -= min(cpus[i], t.cpus)
                    mems[i] -= t.mem
                    launchedTask = True

                if not launchedTask:
                    break

        used = time.time() - start
        if used > 10:
            logger.error('use too much time in resourceOffers: %.2fs', used)

        for o in offers:
            if o.id.value in tasks:
                driver.launchTasks(o.id, tasks[o.id.value])
            else:
                driver.declineOffer(o.id)

        logger.debug('reply with %d tasks, %s cpus %s mem left',
                     sum(len(ts) for ts in tasks.values()),
                     sum(cpus), sum(mems))

    @safe
    def offerRescinded(self, driver, offer_id):
        """Re-request offers after a rescind while jobs are still active."""
        logger.debug('rescinded offer: %s', offer_id)
        if self.activeJobs:
            self.requestMoreResources()

    def getResource(self, res, name):
        """Return the scalar value of resource ``name``, or 0.0 if absent."""
        for r in res:
            if r.name == name:
                return r.scalar.value
        return 0.0

    def getAttribute(self, attrs, name):
        """Return the text value of attribute ``name``, or None if absent."""
        for r in attrs:
            if r.name == name:
                return r.text.value

    def createTask(self, o, job, t):
        """Build the TaskInfo that runs dpark task ``t`` of ``job`` on offer ``o``.

        The task id has the form ``'<job id>:<task id>:<tried>'``; the task
        itself is pickled, compressed and encoded into ``data``.
        """
        task = Dict()
        tid = '%s:%s:%s' % (job.id, t.id, t.tried)
        task.name = 'task %s' % tid
        task.task_id.value = tid
        task.agent_id.value = o.agent_id.value
        task.data = encode_data(
            compress(cPickle.dumps((t, t.tried), -1))
        )
        task.executor = self.executor
        if len(task.data) > 1000 * 1024:
            logger.warning('task too large: %s %d',
                           t, len(task.data))

        resources = task.resources = []

        cpu = Dict()
        resources.append(cpu)
        cpu.name = 'cpus'
        cpu.type = 'SCALAR'
        cpu.scalar.value = t.cpus

        mem = Dict()
        resources.append(mem)
        mem.name = 'mem'
        mem.type = 'SCALAR'
        mem.scalar.value = t.mem
        return task

    @safe
    def statusUpdate(self, driver, status):
        """Route a Mesos task status update to the owning job.

        A ``TASK_RUNNING`` update for a task whose job is gone kills the
        task.  Every other update first removes the task from the
        bookkeeping maps.  For ``TASK_FINISHED``/``TASK_FAILED`` with data,
        the payload is decoded into ``(reason, result, accUpdate)``.  When
        ``result`` is set it is ``(flag, data)``:
        - ``flag >= 2``: ``data`` is a URL to download, retried once on
          IOError; 2 is then subtracted from ``flag``;
        - after decompression, ``flag == 0`` means marshal encoding and any
          other value means pickle.
        """
        tid = status.task_id.value
        state = status.state
        logger.debug('status update: %s %s', tid, state)

        jid = self.taskIdToJobId.get(tid)
        _, task_id, tried = map(int, tid.split(':'))
        if state == 'TASK_RUNNING':
            if jid in self.activeJobs:
                job = self.activeJobs[jid]
                job.statusUpdate(task_id, tried, state)
            else:
                logger.debug('kill task %s as its job has gone', tid)
                self.driver.killTask(Dict(value=tid))

            return

        # Every non-RUNNING update drops the task from the bookkeeping maps.
        self.taskIdToJobId.pop(tid, None)
        if jid in self.jobTasks:
            # discard(): Mesos may deliver the same update more than once,
            # and remove() would then raise KeyError.
            self.jobTasks[jid].discard(tid)
        if tid in self.taskIdToAgentId:
            agent_id = self.taskIdToAgentId[tid]
            if agent_id in self.agentTasks:
                self.agentTasks[agent_id] -= 1
            del self.taskIdToAgentId[tid]

        if jid not in self.activeJobs:
            logger.debug('ignore task %s as its job has gone', tid)
            return

        job = self.activeJobs[jid]
        data = status.get('data')
        if state in ('TASK_FINISHED', 'TASK_FAILED') and data:
            try:
                # NOTE: unpickles data sent back by executors; this trusts
                # every agent in the cluster.
                reason, result, accUpdate = cPickle.loads(
                    decode_data(data))
                if result:
                    flag, data = result
                    if flag >= 2:
                        try:
                            data = urllib.urlopen(data).read()
                        except IOError:
                            # try again
                            data = urllib.urlopen(data).read()
                        flag -= 2
                    data = decompress(data)
                    if flag == 0:
                        result = marshal.loads(data)
                    else:
                        result = cPickle.loads(data)
            except Exception as e:
                logger.warning(
                    'error when cPickle.loads(): %s, data:%s', e, len(data))
                return job.statusUpdate(
                    task_id, tried, 'TASK_FAILED', 'load failed: %s' % e)
            else:
                return job.statusUpdate(task_id, tried, state,
                                        reason, result, accUpdate)

        # killed, lost, load failed
        job.statusUpdate(task_id, tried, state, data)

    def jobFinished(self, job):
        """Forget a finished job and kill its remaining and orphaned tasks."""
        logger.debug('job %s finished', job.id)
        if job.id in self.activeJobs:
            del self.activeJobs[job.id]
            self.activeJobsQueue.remove(job)
            for tid in self.jobTasks[job.id]:
                self.driver.killTask(Dict(value=tid))
            del self.jobTasks[job.id]
            self.last_finish_time = time.time()

            if not self.activeJobs:
                self.agentTasks.clear()

        for tid, jid in self.taskIdToJobId.iteritems():
            if jid not in self.activeJobs:
                logger.debug('kill task %s, because it is orphan', tid)
                self.driver.killTask(Dict(value=tid))

    @safe
    def check(self):
        """Re-request offers when any active job reports a task timeout."""
        for job in self.activeJobs.values():
            if job.check_task_timeout():
                self.requestMoreResources()

    @safe
    def error(self, driver, message):
        """Log a Mesos driver error (not treated as fatal here)."""
        logger.warning('Mesos error message: %s', message)

    # @safe
    def stop(self):
        """Stop the Mesos driver if it is running; a no-op otherwise."""
        if not self.started:
            return
        logger.debug('stop scheduler')
        self.started = False
        self.isRegistered = False
        # Presumably stop(failover=False), i.e. unregister the framework.
        self.driver.stop(False)
        self.driver = None

    def defaultParallelism(self):
        """Return the default number of partitions (fixed at 16)."""
        return 16

    def frameworkMessage(self, driver, executor_id, agent_id, data):
        """Log a message sent by an executor."""
        logger.warning('[agent %s] %s', agent_id.value, data)

    def executorLost(self, driver, executor_id, agent_id, status):
        """Forget the task count of the agent whose executor was lost."""
        logger.warning(
            'executor at %s %s lost: %s',
            agent_id.value,
            executor_id.value,
            status)
        self.agentTasks.pop(agent_id.value, None)

    def slaveLost(self, driver, agent_id):
        """Forget the task count of a lost agent."""
        logger.warning('agent %s lost', agent_id.value)
        self.agentTasks.pop(agent_id.value, None)

    def killTask(self, job_id, task_id, tried):
        """Ask Mesos to kill the task ``'<job_id>:<task_id>:<tried>'``."""
        tid = Dict()
        tid.value = '%s:%s:%s' % (job_id, task_id, tried)
        self.driver.killTask(tid)
コード例 #7
0
ファイル: scheduler.py プロジェクト: douban/tfmesos
class TFMesosScheduler(Scheduler):
    """Mesos framework that launches the tasks of a TensorFlow cluster.

    Every launched task connects back to the socket opened in ``start``.
    Once all tasks have registered, each one is sent the cluster definition
    (``_start_tf_cluster``).  Before the cluster has started, a task that
    reaches a terminal state is relaunched, up to ``MAX_FAILURE_COUNT``
    times per task index.  After it has started, any failure raises
    ``RuntimeError``.
    """
    MAX_FAILURE_COUNT = 3

    def __init__(self, task_spec, role=None, master=None, name=None,
                 quiet=False, volumes=None, containerizer_type=None,
                 force_pull_image=False, forward_addresses=None,
                 protocol='grpc', env=None, extra_config=None):
        """Create one Task per index of every job in ``task_spec``.

        Args:
            task_spec: iterable of job specs.  Each is read for ``name``,
                ``start``, ``num``, ``cpus``, ``mem``, ``gpus`` and ``cmd``;
                tasks get indices ``range(job.start, job.num)``.
            role: Mesos role to register with (default ``'*'``).
            master: Mesos master address (default ``$MESOS_MASTER``).
            name: framework name (default derived from ``sys.argv``).
            quiet: when false, configure the module logger.
            volumes: passed to every Task (default: a new empty dict).
            containerizer_type: containerizer for tasks.  When None, it is
                chosen from the Mesos version in ``registered``.
            force_pull_image: passed through to ``Task.to_task_info``.
            forward_addresses, protocol: sent to every task as part of the
                cluster definition.
            env: passed to every Task (default: a new empty dict).
            extra_config: sent to every task as part of the cluster
                definition (default: a new empty dict).

        Raises:
            KeyError: if ``master`` is not given and ``MESOS_MASTER`` is
                unset.
        """
        # Use fresh dicts per instance.  The former ``{}`` defaults were
        # shared by every scheduler created without these arguments.
        if volumes is None:
            volumes = {}
        if env is None:
            env = {}
        if extra_config is None:
            extra_config = {}

        self.started = False
        self.master = master or os.environ['MESOS_MASTER']
        self.name = name or '[tensorflow] %s %s' % (
            os.path.abspath(sys.argv[0]), ' '.join(sys.argv[1:]))
        self.task_spec = task_spec
        self.containerizer_type = containerizer_type
        self.force_pull_image = force_pull_image
        self.protocol = protocol
        self.extra_config = extra_config
        self.forward_addresses = forward_addresses
        self.role = role or '*'
        self.tasks = {}               # mesos task id -> Task
        self.task_failure_count = {}  # 'job.index' -> failures so far
        self.job_finished = {}        # job name -> tasks finished so far
        for job in task_spec:
            self.job_finished[job.name] = 0
            for task_index in range(job.start, job.num):
                mesos_task_id = str(uuid.uuid4())
                task = Task(
                    mesos_task_id,
                    job.name,
                    task_index,
                    cpus=job.cpus,
                    mem=job.mem,
                    gpus=job.gpus,
                    cmd=job.cmd,
                    volumes=volumes,
                    env=env
                )
                self.tasks[mesos_task_id] = task
                self.task_failure_count[self.decorated_task_index(task)] = 0

        if not quiet:
            global logger
            setup_logger(logger)

    def resourceOffers(self, driver, offers):
        '''
        Offer resources and launch tasks.

        Once every task has been offered, offers are suppressed and
        declined.  Otherwise each offer is packed with not-yet-offered
        tasks whose cpus/mem/gpus fit the offer's remaining resources.
        '''

        for offer in offers:
            if all(task.offered for id, task in iteritems(self.tasks)):
                driver.suppressOffers()
                driver.declineOffer(offer.id, Dict(refuse_seconds=FOREVER))
                continue

            offered_cpus = offered_mem = 0.0
            offered_gpus = []
            offered_tasks = []
            gpu_resource_type = None

            for resource in offer.resources:
                if resource.name == 'cpus':
                    offered_cpus = resource.scalar.value
                elif resource.name == 'mem':
                    offered_mem = resource.scalar.value
                elif resource.name == 'gpus':
                    if resource.type == 'SET':
                        offered_gpus = resource.set.item
                    else:
                        offered_gpus = list(range(int(resource.scalar.value)))

                    gpu_resource_type = resource.type

            for id, task in iteritems(self.tasks):
                if task.offered:
                    continue

                if not (task.cpus <= offered_cpus and
                        task.mem <= offered_mem and
                        task.gpus <= len(offered_gpus)):

                    continue

                offered_cpus -= task.cpus
                offered_mem -= task.mem
                # Fractional GPU requests consume whole devices.
                gpus = int(math.ceil(task.gpus))
                gpu_uuids = offered_gpus[:gpus]
                offered_gpus = offered_gpus[gpus:]
                task.offered = True
                offered_tasks.append(
                    task.to_task_info(
                        offer, self.addr, gpu_uuids=gpu_uuids,
                        gpu_resource_type=gpu_resource_type,
                        containerizer_type=self.containerizer_type,
                        force_pull_image=self.force_pull_image
                    )
                )

            # Called even with an empty list; unused resources of the offer
            # are presumably declined by Mesos.
            driver.launchTasks(offer.id, offered_tasks)

    @property
    def targets(self):
        """Map '/job:<name>/task:<index>' to the task's 'grpc://<addr>'."""
        targets = {}
        for id, task in iteritems(self.tasks):
            target_name = '/job:%s/task:%s' % (task.job_name, task.task_index)
            grpc_addr = 'grpc://%s' % task.addr
            targets[target_name] = grpc_addr
        return targets

    def _start_tf_cluster(self):
        """Send every registered task its cluster definition, then close it.

        Raises:
            RuntimeError: if a task does not acknowledge with ``'ok'``.
        """
        cluster_def = {}

        tasks = sorted(self.tasks.values(), key=lambda task: task.task_index)
        for task in tasks:
            cluster_def.setdefault(task.job_name, []).append(task.addr)

        for id, task in iteritems(self.tasks):
            response = {
                'job_name': task.job_name,
                'task_index': task.task_index,
                'cpus': task.cpus,
                'mem': task.mem,
                'gpus': task.gpus,
                'cmd': task.cmd,
                'cwd': os.getcwd(),
                'cluster_def': cluster_def,
                'forward_addresses': self.forward_addresses,
                'extra_config': self.extra_config,
                'protocol': self.protocol
            }
            send(task.connection, response)
            # Explicit check: an assert would be stripped under ``-O``.
            reply = recv(task.connection)
            if reply != 'ok':
                raise RuntimeError(
                    'Unexpected reply %r from /job:%s/task:%s' %
                    (reply, task.job_name, task.task_index))
            logger.info(
                'Device /job:%s/task:%s activated @ grpc://%s ',
                task.job_name,
                task.task_index,
                task.addr

            )
            task.connection.close()

    def start(self):
        """Launch the cluster and block until every task is activated.

        Opens a listening socket, starts the Mesos driver and waits until
        each task has connected back and sent ``(mesos_task_id, addr)``.
        Then the TF cluster is started.  On any error the scheduler is
        stopped and the exception is re-raised.
        """

        def readable(fd):
            return bool(select.select([fd], [], [], 0.1)[0])

        lfd = socket.socket()
        try:
            lfd.bind(('', 0))
            self.addr = '%s:%s' % (socket.gethostname(), lfd.getsockname()[1])
            lfd.listen(10)
            framework = Dict()
            framework.user = getpass.getuser()
            framework.name = self.name
            framework.hostname = socket.gethostname()
            framework.role = self.role

            self.driver = MesosSchedulerDriver(
                self, framework, self.master, use_addict=True
            )
            self.driver.start()
            task_start_count = 0
            while any((not task.initalized
                       for id, task in iteritems(self.tasks))):
                if readable(lfd):
                    c, _ = lfd.accept()
                    if readable(c):
                        mesos_task_id, addr = recv(c)
                        task = self.tasks[mesos_task_id]
                        task.addr = addr
                        task.connection = c
                        task.initalized = True
                        task_start_count += 1
                        logger.info('Task %s with mesos_task_id %s has '
                                    'registered',
                                    '{}:{}'.format(task.job_name,
                                                   task.task_index),
                                    mesos_task_id)
                        logger.info('Out of %d tasks '
                                    '%d tasks have been registered',
                                    len(self.tasks), task_start_count)
                    else:
                        c.close()

            self.started = True
            self._start_tf_cluster()
        except Exception:
            self.stop()
            raise
        finally:
            lfd.close()

    def registered(self, driver, framework_id, master_info):
        """Log registration and pick a containerizer if none was given.

        Picks 'MESOS' for Mesos >= 1.0.0 and 'DOCKER' otherwise.
        """
        logger.info(
            'Tensorflow cluster registered. '
            '( http://%s:%s/#/frameworks/%s )',
            master_info.hostname, master_info.port, framework_id.value
        )

        if self.containerizer_type is None:
            version = tuple(int(x) for x in driver.version.split("."))
            self.containerizer_type = (
                'MESOS' if version >= (1, 0, 0) else 'DOCKER'
            )

    def statusUpdate(self, driver, update):
        """React to a task reaching a terminal state.

        After the cluster has started, ``TASK_FINISHED`` is counted in
        ``job_finished`` and any other terminal state raises
        ``RuntimeError``.  Before it has started, the task is relaunched
        unless it has already failed ``MAX_FAILURE_COUNT`` times, in which
        case ``RuntimeError`` is raised.
        """
        logger.debug('Received status update %s', str(update.state))
        mesos_task_id = update.task_id.value
        if self._is_terminal_state(update.state):
            task = self.tasks.get(mesos_task_id)
            if task is None:
                # This should be very rare and hence making this info.
                logger.info("Task not found for mesos task id {}"
                            .format(mesos_task_id))
                return
            if self.started:
                if update.state != 'TASK_FINISHED':
                    logger.error('Task failed: %s, %s with state %s', task,
                                 update.message, update.state)
                    raise RuntimeError(
                        'Task %s failed! %s with state %s' %
                        (task, update.message, update.state)
                    )
                else:
                    self.job_finished[task.job_name] += 1
            else:
                logger.warning('Task failed while launching the server: %s, '
                               '%s with state %s', task,
                               update.message, update.state)

                if task.connection:
                    task.connection.close()

                self.task_failure_count[self.decorated_task_index(task)] += 1

                if self._can_revive_task(task):
                    self.revive_task(driver, mesos_task_id, task)
                else:
                    raise RuntimeError('Task %s failed %s with state %s and '
                                       'retries=%s' %
                                       (task, update.message, update.state,
                                        TFMesosScheduler.MAX_FAILURE_COUNT))

    def revive_task(self, driver, mesos_task_id, task):
        """Re-register ``task`` under a fresh mesos id and revive offers."""
        logger.info('Going to revive task %s ', task.task_index)
        self.tasks.pop(mesos_task_id)
        task.offered = False
        task.addr = None
        task.connection = None
        new_task_id = task.mesos_task_id = str(uuid.uuid4())
        self.tasks[new_task_id] = task
        driver.reviveOffers()

    def _can_revive_task(self, task):
        """Return True while ``task`` has failed fewer than MAX_FAILURE_COUNT times."""
        return self.task_failure_count[self.decorated_task_index(task)] < \
            TFMesosScheduler.MAX_FAILURE_COUNT

    @staticmethod
    def decorated_task_index(task):
        """Return the stable per-task key ``'<job_name>.<task_index>'``."""
        return '{}.{}'.format(task.job_name, str(task.task_index))

    @staticmethod
    def _is_terminal_state(task_state):
        """Return True for the states treated as terminal by this scheduler."""
        return task_state in ["TASK_FINISHED", "TASK_FAILED", "TASK_KILLED",
                              "TASK_ERROR"]

    def slaveLost(self, driver, agent_id):
        """Raise RuntimeError if an agent is lost after the cluster started."""
        if self.started:
            logger.error('Slave %s lost:', agent_id.value)
            raise RuntimeError('Slave %s lost' % agent_id)

    def executorLost(self, driver, executor_id, agent_id, status):
        """Raise RuntimeError if an executor is lost after the cluster started."""
        if self.started:
            logger.error('Executor %s lost:', executor_id.value)
            raise RuntimeError('Executor %s@%s lost' % (executor_id, agent_id))

    def error(self, driver, message):
        """Log a Mesos driver error and raise it as RuntimeError."""
        logger.error('Mesos error: %s', message)
        raise RuntimeError('Error ' + message)

    def stop(self):
        """Close open task connections, then stop and join the driver.

        The ``tasks`` and ``driver`` attributes are deleted, so a repeated
        call skips the parts already torn down.
        """
        logger.debug('exit')

        if hasattr(self, 'tasks'):
            for id, task in iteritems(self.tasks):
                if task.connection:
                    task.connection.close()

            del self.tasks

        if hasattr(self, 'driver'):
            self.driver.stop()
            self.driver.join()
            del self.driver

    def finished(self):
        """Return True once every task of at least one job has finished.

        A job owns ``job.num - job.start`` tasks (indices
        ``range(job.start, job.num)``).  Comparing with ``job.num`` alone
        could never succeed for a job with a non-zero ``start``.
        """
        return any(
            self.job_finished[job.name] >= job.num - job.start
            for job in self.task_spec
        )

    def processHeartBeat(self):
        # compatibility with pymesos
        pass
コード例 #8
0
class MesosExecutor(TaskExecutor):
    """Task executor backed by a Mesos scheduler driver.

    The constructor starts the driver on a daemon thread.  That thread
    restarts the driver every time its ``run()`` returns, until ``stop()``
    sets the stopping flag.
    """

    def __init__(
        self,
        role: str,
        callbacks: MesosExecutorCallbacks,
        pool=None,
        principal='taskproc',
        secret=None,
        mesos_address='127.0.0.1:5050',
        initial_decline_delay=1.0,
        framework_name='taskproc-default',
        framework_staging_timeout=240,
        framework_id=None,
        failover=False,
    ) -> None:
        """Build the execution framework and Mesos driver, then start it.

        Args:
            role: Mesos role for the framework.
            callbacks: callbacks handed to the ExecutionFramework.
            pool: passed to the ExecutionFramework.
            principal: principal the driver authenticates as.
            secret: secret the driver authenticates with.
            mesos_address: master URI given to the driver.
            initial_decline_delay: passed to the ExecutionFramework.
            framework_name: name of the framework.
            framework_staging_timeout: passed to the ExecutionFramework as
                ``task_staging_timeout_s``.
            framework_id: passed to the ExecutionFramework.
            failover: given to the driver and reused by ``stop()``.
        """
        self.logger = logging.getLogger(__name__)
        self.role = role
        self.failover = failover

        framework = ExecutionFramework(
            role=role,
            pool=pool,
            name=framework_name,
            callbacks=callbacks,
            task_staging_timeout_s=framework_staging_timeout,
            initial_decline_delay=initial_decline_delay,
            framework_id=framework_id,
        )
        self.execution_framework = framework

        # TODO: Get mesos master ips from smartstack
        self.driver = MesosSchedulerDriver(
            sched=framework,
            framework=framework.framework_info,
            use_addict=True,
            master_uri=mesos_address,
            implicit_acknowledgements=False,
            principal=principal,
            secret=secret,
            failover=failover,
        )

        # The flag must exist before the driver thread first reads it.
        self.stopping = False
        runner = threading.Thread(
            target=self._run_driver, args=(), daemon=True)
        self.driver_thread = runner
        runner.start()

    def _run_driver(self):
        """Run the driver, restarting it each time it returns, until stopping."""
        while True:
            if self.stopping:
                return
            self.driver.run()
            self.logger.warning('Driver stopped, starting again')

    def run(self, task_config):
        """Queue ``task_config`` for launch by the execution framework."""
        framework = self.execution_framework
        framework.enqueue_task(task_config)

    def reconcile(self, task_config):
        """Ask the execution framework to reconcile ``task_config``."""
        framework = self.execution_framework
        framework.reconcile_task(task_config)

    def kill(self, task_id):
        """Kill ``task_id``; returns whatever the framework's kill_task returns."""
        framework = self.execution_framework
        return framework.kill_task(task_id)

    def stop(self):
        """Stop the framework and the driver, then wait for the driver."""
        self.stopping = True
        self.execution_framework.stop()
        driver = self.driver
        driver.stop(failover=self.failover)
        driver.join()

    def get_event_queue(self):
        """Return the execution framework's event queue."""
        framework = self.execution_framework
        return framework.event_queue
コード例 #9
0
ファイル: scheduler.py プロジェクト: windreamer/tfmesos
class TFMesosScheduler(Scheduler):
    """Mesos scheduler that launches and connects a distributed TensorFlow cluster.

    One ``Task`` is created for every index in ``range(job.start, job.num)``
    of every job in ``task_spec``, keyed by a random UUID Mesos task id.
    Tasks are placed on offers with enough cpus/mem/gpus. Once launched, each
    task connects back to the socket opened in :meth:`start` and is then
    sent the full cluster definition.

    A task that reaches a terminal state before the cluster has started is
    relaunched under a new Mesos task id, at most ``MAX_FAILURE_COUNT`` times
    per ``job_name.task_index``. After start-up any failure is fatal.
    """

    MAX_FAILURE_COUNT = 3

    # Mesos task states a task never leaves. TASK_LOST and the
    # TASK_DROPPED / TASK_GONE family must be included: otherwise a task
    # lost during launch is neither retried nor reported, and start() waits
    # for it forever.
    TERMINAL_STATES = frozenset([
        'TASK_FINISHED', 'TASK_FAILED', 'TASK_KILLED', 'TASK_ERROR',
        'TASK_LOST', 'TASK_DROPPED', 'TASK_GONE', 'TASK_GONE_BY_OPERATOR',
    ])

    def __init__(self,
                 task_spec,
                 role=None,
                 master=None,
                 name=None,
                 quiet=False,
                 volumes=None,
                 containerizer_type=None,
                 force_pull_image=False,
                 forward_addresses=None,
                 protocol='grpc',
                 env=None,
                 extra_config=None):
        """Build the task table from ``task_spec``.

        :param task_spec: iterable of job descriptions exposing ``name``,
            ``start``, ``num``, ``cpus``, ``mem``, ``gpus`` and ``cmd``.
        :param role: Mesos role to register under (default ``'*'``).
        :param master: Mesos master address; falls back to the
            ``MESOS_MASTER`` environment variable.
        :param name: framework name; defaults to a description of the
            current command line.
        :param quiet: when true, the module logger is not configured.
        :param volumes: volume mapping given to every ``Task``
            (default: a fresh empty dict).
        :param containerizer_type: containerizer for launched tasks; when
            None it is chosen from the master version in :meth:`registered`.
        :param force_pull_image: forwarded to ``Task.to_task_info``.
        :param forward_addresses: sent to each task in the cluster response.
        :param protocol: server protocol sent to each task.
        :param env: environment mapping given to every ``Task``
            (default: a fresh empty dict).
        :param extra_config: extra settings sent to each task
            (default: a fresh empty dict).
        :raises KeyError: if neither ``master`` nor ``MESOS_MASTER`` is set.
        """
        # Fresh containers per instance instead of shared mutable defaults.
        volumes = {} if volumes is None else volumes
        env = {} if env is None else env
        extra_config = {} if extra_config is None else extra_config

        self.started = False
        self.master = master or os.environ['MESOS_MASTER']
        self.name = name or '[tensorflow] %s %s' % (os.path.abspath(
            sys.argv[0]), ' '.join(sys.argv[1:]))
        self.task_spec = task_spec
        self.containerizer_type = containerizer_type
        self.force_pull_image = force_pull_image
        self.protocol = protocol
        self.extra_config = extra_config
        self.forward_addresses = forward_addresses
        self.role = role or '*'
        self.tasks = {}
        self.task_failure_count = {}
        self.job_finished = {}
        for job in task_spec:
            self.job_finished[job.name] = 0
            for task_index in range(job.start, job.num):
                mesos_task_id = str(uuid.uuid4())
                task = Task(mesos_task_id,
                            job.name,
                            task_index,
                            cpus=job.cpus,
                            mem=job.mem,
                            gpus=job.gpus,
                            cmd=job.cmd,
                            volumes=volumes,
                            env=env)
                self.tasks[mesos_task_id] = task
                self.task_failure_count[self.decorated_task_index(task)] = 0

        if not quiet:
            global logger
            setup_logger(logger)

    def resourceOffers(self, driver, offers):
        '''
        Offer resources and launch tasks.

        Offers are packed greedily with not-yet-offered tasks that fit the
        remaining cpus/mem/gpus. When every task has been offered, further
        offers are suppressed and declined.
        '''
        for offer in offers:
            if all(task.offered for task in self.tasks.values()):
                self.driver.suppressOffers()
                driver.declineOffer(offer.id, Dict(refuse_seconds=FOREVER))
                continue

            offered_cpus = offered_mem = 0.0
            offered_gpus = []
            offered_tasks = []
            gpu_resource_type = None

            for resource in offer.resources:
                if resource.name == 'cpus':
                    offered_cpus = resource.scalar.value
                elif resource.name == 'mem':
                    offered_mem = resource.scalar.value
                elif resource.name == 'gpus':
                    if resource.type == 'SET':
                        offered_gpus = resource.set.item
                    else:
                        offered_gpus = list(range(int(resource.scalar.value)))

                    gpu_resource_type = resource.type

            for task in self.tasks.values():
                if task.offered:
                    continue

                if not (task.cpus <= offered_cpus and task.mem <= offered_mem
                        and task.gpus <= len(offered_gpus)):
                    continue

                offered_cpus -= task.cpus
                offered_mem -= task.mem
                # Fractional GPU requests consume whole devices.
                gpus = int(math.ceil(task.gpus))
                gpu_uuids = offered_gpus[:gpus]
                offered_gpus = offered_gpus[gpus:]
                task.offered = True
                offered_tasks.append(
                    task.to_task_info(
                        offer,
                        self.addr,
                        gpu_uuids=gpu_uuids,
                        gpu_resource_type=gpu_resource_type,
                        containerizer_type=self.containerizer_type,
                        force_pull_image=self.force_pull_image))

            driver.launchTasks(offer.id, offered_tasks)

    @property
    def targets(self):
        """Map ``/job:<name>/task:<index>`` to ``grpc://<addr>`` for each task."""
        targets = {}
        for task in self.tasks.values():
            target_name = '/job:%s/task:%s' % (task.job_name, task.task_index)
            grpc_addr = 'grpc://%s' % task.addr
            targets[target_name] = grpc_addr
        return targets

    def _start_tf_cluster(self):
        """Send the cluster definition to every registered task.

        :raises RuntimeError: if a task does not acknowledge with ``'ok'``.
        """
        cluster_def = {}

        # Order each job's address list by task index.
        tasks = sorted(self.tasks.values(), key=lambda task: task.task_index)
        for task in tasks:
            cluster_def.setdefault(task.job_name, []).append(task.addr)

        for task in self.tasks.values():
            response = {
                'job_name': task.job_name,
                'task_index': task.task_index,
                'cpus': task.cpus,
                'mem': task.mem,
                'gpus': task.gpus,
                'cmd': task.cmd,
                'cwd': os.getcwd(),
                'cluster_def': cluster_def,
                'forward_addresses': self.forward_addresses,
                'extra_config': self.extra_config,
                'protocol': self.protocol
            }
            send(task.connection, response)
            # Explicit check rather than ``assert``: under ``python -O`` the
            # whole assert statement (including the recv call) is removed,
            # so the reply would never even be read.
            reply = recv(task.connection)
            if reply != 'ok':
                raise RuntimeError(
                    'Unexpected reply %r from /job:%s/task:%s' %
                    (reply, task.job_name, task.task_index))
            logger.info('Device /job:%s/task:%s activated @ grpc://%s ',
                        task.job_name, task.task_index, task.addr)
            task.connection.close()

    def start(self):
        """Register with Mesos, wait for all tasks to connect, start the cluster.

        Opens a listening socket whose address is handed to launched tasks,
        starts the driver, and blocks until every task has registered. On
        any error the scheduler is stopped and the exception re-raised.
        """
        def readable(fd):
            return bool(select.select([fd], [], [], 0.1)[0])

        lfd = socket.socket()
        try:
            lfd.bind(('', 0))
            self.addr = '%s:%s' % (socket.gethostname(), lfd.getsockname()[1])
            lfd.listen(10)
            framework = Dict()
            framework.user = getpass.getuser()
            framework.name = self.name
            framework.hostname = socket.gethostname()
            framework.role = self.role

            self.driver = MesosSchedulerDriver(self,
                                               framework,
                                               self.master,
                                               use_addict=True)
            self.driver.start()
            task_start_count = 0
            while any(not task.initalized for task in self.tasks.values()):
                if readable(lfd):
                    c, _ = lfd.accept()
                    if readable(c):
                        mesos_task_id, addr = recv(c)
                        task = self.tasks[mesos_task_id]
                        task.addr = addr
                        task.connection = c
                        task.initalized = True
                        task_start_count += 1
                        logger.info(
                            'Task %s with mesos_task_id %s has '
                            'registered',
                            '{}:{}'.format(task.job_name,
                                           task.task_index), mesos_task_id)
                        logger.info(
                            'Out of %d tasks '
                            '%d tasks have been registered', len(self.tasks),
                            task_start_count)
                    else:
                        c.close()

            self.started = True
            self._start_tf_cluster()
        except Exception:
            self.stop()
            raise
        finally:
            lfd.close()

    def registered(self, driver, framework_id, master_info):
        """Log the framework URL and pick a default containerizer.

        When no containerizer was configured, masters at version 1.0.0 or
        newer get ``'MESOS'``, older ones ``'DOCKER'``.
        """
        logger.info(
            'Tensorflow cluster registered. '
            '( http://%s:%s/#/frameworks/%s )', master_info.hostname,
            master_info.port, framework_id.value)

        if self.containerizer_type is None:
            version = tuple(int(x) for x in driver.version.split("."))
            self.containerizer_type = ('MESOS' if version >=
                                       (1, 0, 0) else 'DOCKER')

    def statusUpdate(self, driver, update):
        """Handle a Mesos status update for one of our tasks.

        Non-terminal states are ignored. After start-up, TASK_FINISHED counts
        towards the job's completion and any other terminal state raises
        ``RuntimeError``. Before start-up the task is relaunched while it has
        retries left; otherwise ``RuntimeError`` is raised.
        """
        logger.debug('Received status update %s', str(update.state))
        mesos_task_id = update.task_id.value
        if not self._is_terminal_state(update.state):
            return

        task = self.tasks.get(mesos_task_id)
        if task is None:
            # This should be very rare and hence making this info.
            logger.info("Task not found for mesos task id %s", mesos_task_id)
            return

        if self.started:
            if update.state != 'TASK_FINISHED':
                logger.error('Task failed: %s, %s with state %s', task,
                             update.message, update.state)
                raise RuntimeError('Task %s failed! %s with state %s' %
                                   (task, update.message, update.state))
            self.job_finished[task.job_name] += 1
            return

        logger.warning(
            'Task failed while launching the server: %s, '
            '%s with state %s', task, update.message, update.state)

        if task.connection:
            task.connection.close()

        self.task_failure_count[self.decorated_task_index(task)] += 1

        if self._can_revive_task(task):
            self.revive_task(driver, mesos_task_id, task)
        else:
            raise RuntimeError('Task %s failed %s with state %s and '
                               'retries=%s' %
                               (task, update.message, update.state,
                                TFMesosScheduler.MAX_FAILURE_COUNT))

    def revive_task(self, driver, mesos_task_id, task):
        """Reset ``task`` and re-register it under a new Mesos task id.

        ``initalized`` is cleared as well, so that ``start()`` waits for the
        relaunched task to register again instead of proceeding with a task
        whose connection has been dropped.
        """
        logger.info('Going to revive task %s ', task.task_index)
        task.offered = False
        task.initalized = False
        task.addr = None
        task.connection = None
        self.tasks.pop(mesos_task_id)
        new_task_id = task.mesos_task_id = str(uuid.uuid4())
        self.tasks[new_task_id] = task
        driver.reviveOffers()

    def _can_revive_task(self, task):
        """Return True while ``task`` has failed fewer than MAX_FAILURE_COUNT times."""
        return self.task_failure_count[self.decorated_task_index(task)] < \
            TFMesosScheduler.MAX_FAILURE_COUNT

    @staticmethod
    def decorated_task_index(task):
        """Return the retry-tracking key ``'<job_name>.<task_index>'``."""
        return '{}.{}'.format(task.job_name, str(task.task_index))

    @staticmethod
    def _is_terminal_state(task_state):
        """Return True if ``task_state`` is one from which a task never returns."""
        return task_state in TFMesosScheduler.TERMINAL_STATES

    def slaveLost(self, driver, agent_id):
        """Abort the running cluster when an agent is lost."""
        if self.started:
            logger.error('Slave %s lost:', agent_id.value)
            raise RuntimeError('Slave %s lost' % agent_id)

    def executorLost(self, driver, executor_id, agent_id, status):
        """Abort the running cluster when an executor is lost."""
        if self.started:
            logger.error('Executor %s lost:', executor_id.value)
            raise RuntimeError('Executor %s@%s lost' % (executor_id, agent_id))

    def error(self, driver, message):
        """Log and raise on a Mesos framework error."""
        logger.error('Mesos error: %s', message)
        raise RuntimeError('Error ' + message)

    def stop(self):
        """Close task connections and stop/join the driver (safe to repeat)."""
        logger.debug('exit')

        if hasattr(self, 'tasks'):
            for task in self.tasks.values():
                if task.connection:
                    task.connection.close()

            del self.tasks

        if hasattr(self, 'driver'):
            self.driver.stop()
            self.driver.join()
            del self.driver

    def finished(self):
        """Return True once every task of any single job has finished.

        A job owns ``job.num - job.start`` tasks (indices ``job.start`` to
        ``job.num - 1``), so that is the count compared against; comparing
        against ``job.num`` could never succeed when ``job.start > 0``.
        """
        return any(self.job_finished[job.name] >= job.num - job.start
                   for job in self.task_spec)
コード例 #10
0
ファイル: scheduler.py プロジェクト: tjsongzw/tfmesos
class TFMesosScheduler(Scheduler):
    """Mesos scheduler that launches and connects a distributed TensorFlow cluster.

    Tasks live in a list and their position in it doubles as the Mesos task
    id. Once every task has registered back over the socket opened in
    :meth:`start`, each one is sent the cluster definition, optionally
    including a caller-run ``local_task``.
    """

    def __init__(self, task_spec, master=None, name=None, quiet=False,
                 volumes=None, local_task=None):
        """Build the task list from ``task_spec``.

        :param task_spec: iterable of job descriptions exposing ``name``,
            ``start``, ``num``, ``cpus``, ``mem``, ``gpus`` and ``cmd``.
        :param master: Mesos master address; falls back to the
            ``MESOS_MASTER`` environment variable.
        :param name: framework name; defaults to a description of the
            current command line.
        :param quiet: when true, the module logger is not configured.
        :param volumes: volume mapping given to every ``Task``
            (default: a fresh empty dict).
        :param local_task: optional ``(job_name, addr)`` pair; ``addr`` is
            inserted first in that job's list of the cluster definition.
        :raises KeyError: if neither ``master`` nor ``MESOS_MASTER`` is set.
        """
        # Fresh dict per instance instead of a shared mutable default.
        if volumes is None:
            volumes = {}
        self.started = False
        self.master = master or os.environ['MESOS_MASTER']
        self.name = name or '[tensorflow] %s %s' % (
            os.path.abspath(sys.argv[0]), ' '.join(sys.argv[1:]))
        self.local_task = local_task
        self.task_spec = task_spec
        self.tasks = []
        for job in task_spec:
            for task_index in range(job.start, job.num):
                mesos_task_id = len(self.tasks)
                self.tasks.append(
                    Task(
                        mesos_task_id,
                        job.name,
                        task_index,
                        cpus=job.cpus,
                        mem=job.mem,
                        gpus=job.gpus,
                        cmd=job.cmd,
                        volumes=volumes
                    )
                )
        if not quiet:
            global logger
            setup_logger(logger)

    def resourceOffers(self, driver, offers):
        '''
        Offer resources and launch tasks.

        Offers are packed greedily with not-yet-offered tasks that fit the
        remaining cpus/mem/gpus; once every task is offered, offers are
        declined.
        '''
        for offer in offers:
            if all(task.offered for task in self.tasks):
                driver.declineOffer(offer.id, Dict(refuse_seconds=FOREVER))
                continue

            offered_cpus = offered_mem = 0.0
            offered_gpus = []
            offered_tasks = []
            gpu_resource_type = None

            for resource in offer.resources:
                if resource.name == 'cpus':
                    offered_cpus = resource.scalar.value
                elif resource.name == 'mem':
                    offered_mem = resource.scalar.value
                elif resource.name == 'gpus':
                    if resource.type == 'SET':
                        offered_gpus = resource.set.item
                    else:
                        offered_gpus = list(range(int(resource.scalar.value)))

                    gpu_resource_type = resource.type

            for task in self.tasks:
                if task.offered:
                    continue

                if not (task.cpus <= offered_cpus and
                        task.mem <= offered_mem and
                        task.gpus <= len(offered_gpus)):
                    continue

                offered_cpus -= task.cpus
                offered_mem -= task.mem
                # Fractional GPU requests consume whole devices.
                gpus = int(math.ceil(task.gpus))
                gpu_uuids = offered_gpus[:gpus]
                offered_gpus = offered_gpus[gpus:]
                task.offered = True
                offered_tasks.append(
                    task.to_task_info(
                        offer, self.addr, gpu_uuids=gpu_uuids,
                        gpu_resource_type=gpu_resource_type
                    )
                )

            driver.launchTasks(offer.id, offered_tasks)

    def _start_tf_cluster(self):
        """Send the cluster definition to every task and return the targets.

        :returns: dict mapping ``/job:<name>/task:<index>`` to
            ``grpc://<addr>``.
        :raises RuntimeError: if a task does not acknowledge with ``'ok'``.
        """
        cluster_def = {}

        targets = {}
        for task in self.tasks:
            target_name = '/job:%s/task:%s' % (task.job_name, task.task_index)
            grpc_addr = 'grpc://%s' % task.addr
            targets[target_name] = grpc_addr
            cluster_def.setdefault(task.job_name, []).append(task.addr)

        if self.local_task:
            job_name, addr = self.local_task
            cluster_def.setdefault(job_name, []).insert(0, addr)

        for task in self.tasks:
            response = {
                'job_name': task.job_name,
                'task_index': task.task_index,
                'cpus': task.cpus,
                'mem': task.mem,
                'gpus': task.gpus,
                'cmd': task.cmd,
                'cwd': os.getcwd(),
                'cluster_def': cluster_def,
            }
            send(task.connection, response)
            # Explicit check rather than ``assert``: under ``python -O`` the
            # whole assert statement (including the recv call) is removed.
            reply = recv(task.connection)
            if reply != 'ok':
                raise RuntimeError(
                    'Unexpected reply %r from /job:%s/task:%s' %
                    (reply, task.job_name, task.task_index))
            logger.info(
                'Device /job:%s/task:%s activated @ grpc://%s ',
                task.job_name,
                task.task_index,
                task.addr
            )
            task.connection.close()
        return targets

    def start(self):
        """Register with Mesos, wait for all tasks to connect, start the cluster.

        :returns: the targets mapping produced by ``_start_tf_cluster``.
        On any error the scheduler is stopped and the exception re-raised.
        """
        def readable(fd):
            return bool(select.select([fd], [], [], 0.1)[0])

        lfd = socket.socket()
        try:
            lfd.bind(('', 0))
            self.addr = '%s:%s' % (socket.gethostname(), lfd.getsockname()[1])
            lfd.listen(10)
            framework = Dict()
            framework.user = getpass.getuser()
            framework.name = self.name
            framework.hostname = socket.gethostname()

            self.driver = MesosSchedulerDriver(
                self, framework, self.master, use_addict=True
            )
            self.driver.start()
            while any(not task.initalized for task in self.tasks):
                if readable(lfd):
                    c, _ = lfd.accept()
                    if readable(c):
                        mesos_task_id, addr = recv(c)
                        assert isinstance(mesos_task_id, int)
                        task = self.tasks[mesos_task_id]
                        task.addr = addr
                        task.connection = c
                        task.initalized = True
                    else:
                        c.close()

            self.started = True
            return self._start_tf_cluster()
        except Exception:
            self.stop()
            raise
        finally:
            lfd.close()

    def registered(self, driver, framework_id, master_info):
        """Log the Mesos UI URL of the newly registered framework."""
        logger.info(
            'Tensorflow cluster registered. '
            '( http://%s:%s/#/frameworks/%s )',
            master_info.hostname, master_info.port, framework_id.value
        )

    def statusUpdate(self, driver, update):
        """React to a Mesos status update.

        Any state other than TASK_RUNNING is handled as the task stopping.
        After start-up anything but TASK_FINISHED raises ``RuntimeError``.
        Before start-up the task's connection is closed and offers are
        revived.
        """
        mesos_task_id = int(update.task_id.value)
        if update.state == 'TASK_RUNNING':
            return

        task = self.tasks[mesos_task_id]
        if self.started:
            if update.state != 'TASK_FINISHED':
                logger.error('Task failed: %s, %s', task, update.message)
                # Report the failed task itself; this previously formatted
                # the builtin ``id`` function into the message.
                raise RuntimeError(
                    'Task %s failed! %s' % (task, update.message)
                )
        else:
            logger.warning('Task failed: %s, %s', task, update.message)
            if task.connection:
                task.connection.close()

            # NOTE(review): task.offered is not reset here, so
            # resourceOffers() will not relaunch this task and start() may
            # keep waiting for it - confirm the intended retry behaviour.
            driver.reviveOffers()

    def slaveLost(self, driver, agent_id):
        """Abort the running cluster when an agent is lost."""
        if self.started:
            logger.error('Slave %s lost:', agent_id.value)
            raise RuntimeError('Slave %s lost' % agent_id)

    def executorLost(self, driver, executor_id, agent_id, status):
        """Abort the running cluster when an executor is lost."""
        if self.started:
            logger.error('Executor %s lost:', executor_id.value)
            raise RuntimeError('Executor %s@%s lost' % (executor_id, agent_id))

    def error(self, driver, message):
        """Log and raise on a Mesos framework error."""
        logger.error('Mesos error: %s', message)
        raise RuntimeError('Error ' + message)

    def stop(self):
        """Close task connections and stop the driver (safe to repeat)."""
        logger.debug('exit')

        if hasattr(self, 'tasks'):
            for task in self.tasks:
                if task.connection:
                    task.connection.close()

            del self.tasks

        if hasattr(self, 'driver'):
            self.driver.stop()
            del self.driver
コード例 #11
0
class TFMesosScheduler(Scheduler):
    """Mesos scheduler that launches and connects a distributed TensorFlow cluster.

    Tasks live in a list and their position in it doubles as the Mesos task
    id. Once every task has registered back over the socket opened in
    :meth:`start`, each one is sent the cluster definition.
    """

    def __init__(self,
                 task_spec,
                 master=None,
                 name=None,
                 quiet=False,
                 volumes=None,
                 containerizer_type=None,
                 forward_addresses=None,
                 protocol='grpc'):
        """Build the task list from ``task_spec``.

        :param task_spec: iterable of job descriptions exposing ``name``,
            ``start``, ``num``, ``cpus``, ``mem``, ``gpus`` and ``cmd``.
        :param master: Mesos master address; falls back to the
            ``MESOS_MASTER`` environment variable.
        :param name: framework name; defaults to a description of the
            current command line.
        :param quiet: when true, the module logger is not configured.
        :param volumes: volume mapping given to every ``Task``
            (default: a fresh empty dict).
        :param containerizer_type: containerizer for launched tasks; when
            None it is chosen from the master version in :meth:`registered`.
        :param forward_addresses: sent to each task in the cluster response.
        :param protocol: server protocol sent to each task.
        :raises KeyError: if neither ``master`` nor ``MESOS_MASTER`` is set.
        """
        # Fresh dict per instance instead of a shared mutable default.
        if volumes is None:
            volumes = {}
        self.started = False
        self.master = master or os.environ['MESOS_MASTER']
        self.name = name or '[tensorflow] %s %s' % (os.path.abspath(
            sys.argv[0]), ' '.join(sys.argv[1:]))
        self.task_spec = task_spec
        self.containerizer_type = containerizer_type
        self.protocol = protocol
        self.forward_addresses = forward_addresses
        self.tasks = []
        self.job_finished = {}
        for job in task_spec:
            self.job_finished[job.name] = 0
            for task_index in range(job.start, job.num):
                mesos_task_id = len(self.tasks)
                self.tasks.append(
                    Task(mesos_task_id,
                         job.name,
                         task_index,
                         cpus=job.cpus,
                         mem=job.mem,
                         gpus=job.gpus,
                         cmd=job.cmd,
                         volumes=volumes))
        if not quiet:
            global logger
            setup_logger(logger)

    def resourceOffers(self, driver, offers):
        '''
        Offer resources and launch tasks.

        Offers are packed greedily with not-yet-offered tasks that fit the
        remaining cpus/mem/gpus. When every task has been offered, further
        offers are suppressed and declined.
        '''
        for offer in offers:
            if all(task.offered for task in self.tasks):
                self.driver.suppressOffers()
                driver.declineOffer(offer.id, Dict(refuse_seconds=FOREVER))
                continue

            offered_cpus = offered_mem = 0.0
            offered_gpus = []
            offered_tasks = []
            gpu_resource_type = None

            for resource in offer.resources:
                if resource.name == 'cpus':
                    offered_cpus = resource.scalar.value
                elif resource.name == 'mem':
                    offered_mem = resource.scalar.value
                elif resource.name == 'gpus':
                    if resource.type == 'SET':
                        offered_gpus = resource.set.item
                    else:
                        offered_gpus = list(range(int(resource.scalar.value)))

                    gpu_resource_type = resource.type

            for task in self.tasks:
                if task.offered:
                    continue

                if not (task.cpus <= offered_cpus and task.mem <= offered_mem
                        and task.gpus <= len(offered_gpus)):
                    continue

                offered_cpus -= task.cpus
                offered_mem -= task.mem
                # Fractional GPU requests consume whole devices.
                gpus = int(math.ceil(task.gpus))
                gpu_uuids = offered_gpus[:gpus]
                offered_gpus = offered_gpus[gpus:]
                task.offered = True
                offered_tasks.append(
                    task.to_task_info(
                        offer,
                        self.addr,
                        gpu_uuids=gpu_uuids,
                        gpu_resource_type=gpu_resource_type,
                        containerizer_type=self.containerizer_type))

            driver.launchTasks(offer.id, offered_tasks)

    @property
    def targets(self):
        """Map ``/job:<name>/task:<index>`` to ``grpc://<addr>`` for each task."""
        targets = {}
        for task in self.tasks:
            target_name = '/job:%s/task:%s' % (task.job_name, task.task_index)
            grpc_addr = 'grpc://%s' % task.addr
            targets[target_name] = grpc_addr
        return targets

    def _start_tf_cluster(self):
        """Send the cluster definition to every registered task.

        :raises RuntimeError: if a task does not acknowledge with ``'ok'``.
        """
        cluster_def = {}

        for task in self.tasks:
            cluster_def.setdefault(task.job_name, []).append(task.addr)

        for task in self.tasks:
            response = {
                'job_name': task.job_name,
                'task_index': task.task_index,
                'cpus': task.cpus,
                'mem': task.mem,
                'gpus': task.gpus,
                'cmd': task.cmd,
                'cwd': os.getcwd(),
                'cluster_def': cluster_def,
                'forward_addresses': self.forward_addresses,
                'protocol': self.protocol
            }
            send(task.connection, response)
            # Explicit check rather than ``assert``: under ``python -O`` the
            # whole assert statement (including the recv call) is removed.
            reply = recv(task.connection)
            if reply != 'ok':
                raise RuntimeError(
                    'Unexpected reply %r from /job:%s/task:%s' %
                    (reply, task.job_name, task.task_index))
            logger.info('Device /job:%s/task:%s activated @ grpc://%s ',
                        task.job_name, task.task_index, task.addr)
            task.connection.close()

    def start(self):
        """Register with Mesos, wait for all tasks to connect, start the cluster.

        On any error the scheduler is stopped and the exception re-raised.
        """
        def readable(fd):
            return bool(select.select([fd], [], [], 0.1)[0])

        lfd = socket.socket()
        try:
            lfd.bind(('', 0))
            self.addr = '%s:%s' % (socket.gethostname(), lfd.getsockname()[1])
            lfd.listen(10)
            framework = Dict()
            framework.user = getpass.getuser()
            framework.name = self.name
            framework.hostname = socket.gethostname()

            self.driver = MesosSchedulerDriver(self,
                                               framework,
                                               self.master,
                                               use_addict=True)
            self.driver.start()
            while any(not task.initalized for task in self.tasks):
                if readable(lfd):
                    c, _ = lfd.accept()
                    if readable(c):
                        mesos_task_id, addr = recv(c)
                        assert isinstance(mesos_task_id, int)
                        task = self.tasks[mesos_task_id]
                        task.addr = addr
                        task.connection = c
                        task.initalized = True
                    else:
                        c.close()

            self.started = True
            self._start_tf_cluster()
        except Exception:
            self.stop()
            raise
        finally:
            lfd.close()

    def registered(self, driver, framework_id, master_info):
        """Log the framework URL and pick a default containerizer.

        When no containerizer was configured, masters at version 1.0.0 or
        newer get ``'MESOS'``, older ones ``'DOCKER'``.
        """
        logger.info(
            'Tensorflow cluster registered. '
            '( http://%s:%s/#/frameworks/%s )', master_info.hostname,
            master_info.port, framework_id.value)

        if self.containerizer_type is None:
            version = tuple(int(x) for x in driver.version.split("."))
            self.containerizer_type = ('MESOS' if version >=
                                       (1, 0, 0) else 'DOCKER')

    def statusUpdate(self, driver, update):
        """React to a Mesos status update.

        Any state other than TASK_RUNNING is handled as the task stopping.
        After start-up, TASK_FINISHED counts towards the job's completion and
        anything else raises ``RuntimeError``. Before start-up the task's
        connection is closed and offers are revived.
        """
        mesos_task_id = int(update.task_id.value)
        if update.state == 'TASK_RUNNING':
            return

        task = self.tasks[mesos_task_id]
        if self.started:
            if update.state != 'TASK_FINISHED':
                logger.error('Task failed: %s, %s', task, update.message)
                raise RuntimeError('Task %s failed! %s' %
                                   (task, update.message))
            self.job_finished[task.job_name] += 1
        else:
            logger.warning('Task failed: %s, %s', task, update.message)
            if task.connection:
                task.connection.close()

            # NOTE(review): task.offered is not reset here, so
            # resourceOffers() will not relaunch this task - confirm the
            # intended retry behaviour.
            driver.reviveOffers()

    def slaveLost(self, driver, agent_id):
        """Abort the running cluster when an agent is lost."""
        if self.started:
            logger.error('Slave %s lost:', agent_id.value)
            raise RuntimeError('Slave %s lost' % agent_id)

    def executorLost(self, driver, executor_id, agent_id, status):
        """Abort the running cluster when an executor is lost."""
        if self.started:
            logger.error('Executor %s lost:', executor_id.value)
            raise RuntimeError('Executor %s@%s lost' % (executor_id, agent_id))

    def error(self, driver, message):
        """Log and raise on a Mesos framework error."""
        logger.error('Mesos error: %s', message)
        raise RuntimeError('Error ' + message)

    def stop(self):
        """Close task connections and stop/join the driver (safe to repeat)."""
        logger.debug('exit')

        if hasattr(self, 'tasks'):
            for task in self.tasks:
                if task.connection:
                    task.connection.close()

            del self.tasks

        if hasattr(self, 'driver'):
            self.driver.stop()
            self.driver.join()
            del self.driver

    def finished(self):
        """Return True once every task of any single job has finished.

        A job owns ``job.num - job.start`` tasks (indices ``job.start`` to
        ``job.num - 1``), so that is the count compared against; comparing
        against ``job.num`` could never succeed when ``job.start > 0``.
        """
        return any(self.job_finished[job.name] >= job.num - job.start
                   for job in self.task_spec)
コード例 #12
0
class ProcScheduler(Scheduler):
    """Mesos scheduler that runs submitted procs as Mesos tasks.

    Bookkeeping (every mutation happens under ``self._lock``):

    * ``procs_pending``: proc id -> proc, submitted but not yet launched.
    * ``procs_launched``: proc id -> proc, passed to ``driver.launchTasks``.
    * ``agent_to_proc``: agent id -> set of proc ids reported TASK_RUNNING on
      that agent; used to route framework messages (``send_data``) and to
      fail procs when their agent or executor is lost.
    """

    def __init__(self):
        self.framework_id = None
        self.framework = self._init_framework()
        self.executor = None
        # Only consult $MESOS_MASTER as a fallback, so it need not be set
        # when CONFIG already provides a master address.
        master = CONFIG.get('master', None)
        if master is None:
            master = os.environ['MESOS_MASTER']
        self.master = str(master)
        self.driver = MesosSchedulerDriver(self, self.framework, self.master)
        self.procs_pending = {}
        self.procs_launched = {}
        self.agent_to_proc = {}
        self._lock = RLock()

    def _init_framework(self):
        """Return the FrameworkInfo dict (current user, repr, hostname)."""
        framework = dict(
            user=getpass.getuser(),
            name=repr(self),
            hostname=socket.gethostname(),
        )
        return framework

    def _init_executor(self):
        """Build the ExecutorInfo dict shared by every task we launch.

        The executor runs ``<python> -m <package>.executor`` and requests
        MIN_MEMORY / MIN_CPUS for itself. The current PYTHONPATH, if set, is
        forwarded through the command's environment.
        """
        executor = dict(
            executor_id=dict(value='default'),
            framework_id=self.framework_id,
            command=dict(value='%s -m %s.executor' %
                         (sys.executable, __package__)),
            resources=[
                dict(
                    name='mem',
                    type='SCALAR',
                    scalar=dict(value=MIN_MEMORY),
                ),
                dict(name='cpus', type='SCALAR', scalar=dict(value=MIN_CPUS)),
            ],
        )

        if 'PYTHONPATH' in os.environ:
            # ``environment`` is a field of CommandInfo, so it has to be
            # nested under ``command``; a flat 'command.environment' key is
            # not an ExecutorInfo field Mesos knows about.
            executor['command']['environment'] = dict(variables=[
                dict(
                    name='PYTHONPATH',
                    value=os.environ['PYTHONPATH'],
                ),
            ])

        return executor

    def _init_task(self, proc, offer):
        """Build the TaskInfo dict that runs ``proc`` on ``offer``'s agent.

        ``proc.params`` is pickled and base64-encoded into the task data.
        """
        resources = [
            dict(
                name='cpus',
                type='SCALAR',
                scalar=dict(value=proc.cpus),
            ),
            dict(
                name='mem',
                type='SCALAR',
                scalar=dict(value=proc.mem),
            )
        ]

        if proc.gpus > 0:
            resources.append(
                dict(
                    name='gpus',
                    type='SCALAR',
                    scalar=dict(value=proc.gpus),
                ))

        task = dict(
            task_id=dict(value=str(proc.id)),
            name=repr(proc),
            executor=self.executor,
            agent_id=offer['agent_id'],
            data=b2a_base64(pickle.dumps(proc.params)).strip(),
            resources=resources,
        )

        return task

    def _filters(self, seconds):
        """Return Mesos offer filters refusing the offer for ``seconds``."""
        f = dict(refuse_seconds=seconds)
        return f

    def __repr__(self):
        return "%s[%s]: %s" % (self.__class__.__name__, os.getpid(), ' '.join(
            sys.argv))

    def registered(self, driver, framework_id, master_info):
        """Remember the framework id and build the executor description."""
        with self._lock:
            logger.info('Framework registered with id=%s, master=%s',
                        framework_id, master_info)
            self.framework_id = framework_id
            self.executor = self._init_executor()

    def resourceOffers(self, driver, offers):
        """Greedily pack pending procs onto the incoming offers.

        Offers are shuffled, then each one takes as many pending procs as
        fit. Unused offers are declined for a random 5-10 s, or for FOREVER
        when nothing is pending.
        """
        def get_resources(offer):
            cpus, mem, gpus = 0.0, 0.0, 0
            for r in offer['resources']:
                if r['name'] == 'cpus':
                    cpus = float(r['scalar']['value'])
                elif r['name'] == 'mem':
                    mem = float(r['scalar']['value'])
                elif r['name'] == 'gpus':
                    gpus = int(r['scalar']['value'])

            return cpus, mem, gpus

        with self._lock:
            random.shuffle(offers)
            for offer in offers:
                if not self.procs_pending:
                    logger.debug('Reject offers forever for no pending procs, '
                                 'offers=%s' % (offers, ))
                    driver.declineOffer(offer['id'], self._filters(FOREVER))
                    continue

                cpus, mem, gpus = get_resources(offer)
                tasks = []
                for proc in list(self.procs_pending.values()):
                    # Keep MIN_CPUS / MIN_MEMORY of headroom, matching the
                    # executor's own resource request in _init_executor.
                    if (cpus >= proc.cpus + MIN_CPUS
                            and mem >= proc.mem + MIN_MEMORY
                            and gpus >= proc.gpus):
                        tasks.append(self._init_task(proc, offer))
                        del self.procs_pending[proc.id]
                        self.procs_launched[proc.id] = proc
                        cpus -= proc.cpus
                        mem -= proc.mem
                        gpus -= proc.gpus

                # Randomized refusal so declined offers don't all come back
                # at the same moment.
                seconds = 5 + random.random() * 5
                if tasks:
                    logger.info(
                        'Accept offer for procs, offer=%s, '
                        'procs=%s, filter_time=%s' %
                        (offer, [int(t['task_id']['value'])
                                 for t in tasks], seconds))
                    driver.launchTasks(offer['id'], tasks,
                                       self._filters(seconds))
                else:
                    logger.info('Retry offer for procs later, offer=%s, '
                                'filter_time=%s' % (offer, seconds))
                    driver.declineOffer(offer['id'], self._filters(seconds))

    def _call_finished(self, proc_id, success, message, data, agent_id=None):
        """Forget ``proc_id`` and report its outcome via ``proc._finished``.

        The proc is looked up in ``procs_launched`` first and then in
        ``procs_pending`` (``error`` finishes pending procs too). If it is in
        neither -- e.g. ``cancel`` already dropped it before the TASK_KILLED
        update arrived -- only the agent bookkeeping is cleaned up.
        When ``agent_id`` is None, every agent's proc set is searched.
        """
        with self._lock:
            proc = self.procs_launched.pop(proc_id, None)
            if proc is None:
                proc = self.procs_pending.pop(proc_id, None)

            if agent_id is not None:
                procs = self.agent_to_proc.get(agent_id)
                if procs is not None:
                    # discard: the proc may have failed before TASK_RUNNING
                    # ever registered it on this agent.
                    procs.discard(proc_id)
            else:
                for procs in self.agent_to_proc.values():
                    procs.discard(proc_id)

            if proc is not None:
                proc._finished(success, message, data)

    def statusUpdate(self, driver, update):
        """Track task state changes reported by Mesos.

        TASK_RUNNING records the proc on its agent and calls
        ``proc._started()``. Any state other than TASK_STAGING, TASK_STARTING
        or TASK_RUNNING is terminal: the optional base64/pickled payload in
        ``update['data']`` is decoded, the proc is finished (``success`` only
        for TASK_FINISHED), and offers are revived.
        """
        with self._lock:
            proc_id = int(update['task_id']['value'])
            logger.info('Status update for proc, id=%s, state=%s',
                        proc_id, update['state'])
            agent_id = update['agent_id']['value']
            if update['state'] == 'TASK_RUNNING':
                proc = self.procs_launched.get(proc_id)
                if proc is None:
                    # Cancelled or already finished before this update
                    # arrived; there is nothing left to track.
                    return
                self.agent_to_proc.setdefault(agent_id, set()).add(proc_id)
                proc._started()

            elif update['state'] not in {'TASK_STAGING', 'TASK_STARTING'}:
                success = (update['state'] == 'TASK_FINISHED')
                message = update.get('message')
                data = update.get('data')
                if data:
                    # NOTE(review): unpickles executor-supplied bytes; this
                    # is only safe if the executors are trusted.
                    data = pickle.loads(a2b_base64(data))

                self._call_finished(proc_id, success, message, data, agent_id)
                driver.reviveOffers()

    def offerRescinded(self, driver, offer_id):
        """Ask for fresh offers if procs are still waiting to launch."""
        with self._lock:
            if self.procs_pending:
                logger.info('Revive offers for pending procs')
                driver.reviveOffers()

    def executorLost(self, driver, executor_id, agent_id, status):
        """Fail every proc that was running on the lost executor's agent."""
        agent_id = agent_id['value']
        with self._lock:
            for proc_id in self.agent_to_proc.pop(agent_id, []):
                self._call_finished(proc_id, False, 'Executor lost', None,
                                    agent_id)

    def slaveLost(self, driver, agent_id):
        """Fail every proc that was running on the lost agent."""
        agent_id = agent_id['value']
        with self._lock:
            for proc_id in self.agent_to_proc.pop(agent_id, []):
                self._call_finished(proc_id, False, 'Agent lost', None,
                                    agent_id)

    def error(self, driver, message):
        """Fail all pending and launched procs with ``message``, then stop."""
        with self._lock:
            for proc in list(self.procs_pending.values()):
                self._call_finished(proc.id, False, message, None)

            for proc in list(self.procs_launched.values()):
                self._call_finished(proc.id, False, message, None)

        self.stop()

    def start(self):
        self.driver.start()

    def stop(self):
        assert not self.driver.aborted
        self.driver.stop()

    def submit(self, proc):
        """Queue ``proc`` for launch; revive offers on the first pending proc.

        Raises:
            RuntimeError: the driver has already aborted.
            ValueError: a proc with the same id is already pending.
        """
        if self.driver.aborted:
            raise RuntimeError('driver already aborted')

        with self._lock:
            if proc.id not in self.procs_pending:
                logger.info('Try submit proc, id=%s', proc.id)
                self.procs_pending[proc.id] = proc
                if len(self.procs_pending) == 1:
                    # resourceOffers declines offers FOREVER while nothing is
                    # pending, so offers must be revived explicitly.
                    logger.info('Revive offers for pending procs')
                    self.driver.reviveOffers()
            else:
                raise ValueError('Proc with same id already submitted')

    def cancel(self, proc):
        """Drop ``proc`` from scheduling, killing its task if launched.

        Raises:
            RuntimeError: the driver has already aborted.
        """
        if self.driver.aborted:
            raise RuntimeError('driver already aborted')

        with self._lock:
            if proc.id in self.procs_pending:
                del self.procs_pending[proc.id]
            elif proc.id in self.procs_launched:
                del self.procs_launched[proc.id]
                self.driver.killTask(dict(value=str(proc.id)))

            for agent_id, procs in list(self.agent_to_proc.items()):
                # Values are sets: set.pop() takes no argument, so use
                # discard() to drop this proc id if present.
                procs.discard(proc.id)
                if not procs:
                    del self.agent_to_proc[agent_id]

    def send_data(self, pid, type, data):
        """Send a pickled ``(pid, type, data)`` message to ``pid``'s executor.

        Raises:
            RuntimeError: the driver has aborted, or no agent runs ``pid``.
        """
        if self.driver.aborted:
            raise RuntimeError('driver already aborted')

        msg = b2a_base64(pickle.dumps((pid, type, data)))
        for agent_id, procs in list(self.agent_to_proc.items()):
            if pid in procs:
                self.driver.sendFrameworkMessage(self.executor['executor_id'],
                                                 dict(value=agent_id), msg)
                return

        raise RuntimeError('Cannot find agent for pid %s' % (pid, ))
コード例 #13
0
ファイル: scheduler.py プロジェクト: douban/dpark
    signal.signal(signal.SIGABRT, handler)
    signal.signal(signal.SIGQUIT, handler)

    spawn_rconsole(locals())

    try:
        driver.start()
        sched.run(driver)
    except KeyboardInterrupt:
        logger.warning('stopped by KeyboardInterrupt')
        sched.stop(EXIT_KEYBORAD)
    except Exception as e:
        import traceback

        logger.warning('catch unexpected Exception, exit now. %s',
                       traceback.format_exc())
        sched.stop(EXIT_EXCEPTION)
    finally:
        try:
            sched.dump_stats()
        except:
            logger.exception("dump stats fail, ignore it.")
        # sched.lock may be in WRONG status.
        # if any thread of sched may use lock or call driver, join it first
        driver.stop(False)
        driver.join()
        # mesos resourses are released, and no racer for lock any more
        sched.cleanup()
        ctx.term()
        sys.exit(sched.ec)
コード例 #14
0
ファイル: schedule.py プロジェクト: zouzias/dpark
class MesosScheduler(DAGScheduler):
    """DAGScheduler that runs task sets on Apache Mesos.

    ``submitTasks`` wraps a stage's tasks in a ``SimpleJob``, and
    ``resourceOffers`` packs pending tasks from active jobs onto the offers
    Mesos sends. Per-instance bookkeeping:

    * ``activeJobs`` / ``activeJobsQueue``: job id -> job, and the jobs in
      submission order (offers are filled in that order).
    * ``jobTasks``: job id -> set of Mesos task ids launched for that job.
    * ``taskIdToJobId`` / ``taskIdToAgentId``: Mesos task id -> job id /
      agent id.
    * ``agentTasks``: agent id -> number of tasks currently dispatched there,
      capped at ``task_per_node``.

    NOTE(review): ``submitTasks`` releases and re-acquires ``self.lock``, so
    it presumably relies on its caller (likely the ``@safe`` wrapper) already
    holding that lock -- confirm against ``safe``.
    """

    def __init__(self, master, options):
        DAGScheduler.__init__(self)
        self.master = master
        self.use_self_as_exec = options.self
        self.cpus = options.cpus
        self.mem = options.mem
        self.task_per_node = options.parallel or multiprocessing.cpu_count()
        self.group = options.group
        self.logLevel = options.logLevel
        self.options = options
        self.started = False
        self.last_finish_time = 0
        self.isRegistered = False
        self.executor = None
        self.driver = None
        self.out_logger = None
        self.err_logger = None
        self.lock = threading.RLock()
        self.init_job()

    def init_job(self):
        """Reset all per-job bookkeeping to empty."""
        self.activeJobs = {}
        self.activeJobsQueue = []
        self.taskIdToJobId = {}
        self.taskIdToAgentId = {}
        self.jobTasks = {}
        self.agentTasks = {}

    def clear(self):
        DAGScheduler.clear(self)
        self.init_job()

    def start(self):
        """Start the stdout/stderr log collectors if not already running."""
        if not self.out_logger:
            self.out_logger = self.start_logger(sys.stdout)
        if not self.err_logger:
            self.err_logger = self.start_logger(sys.stderr)

    def start_driver(self):
        """Create and start the Mesos driver plus an idle watchdog.

        The watchdog stops the scheduler once there have been no active jobs
        for more than MAX_IDLE_TIME seconds.

        Raises:
            Exception: when running as root.
        """
        name = '[dpark] ' + \
            os.path.abspath(sys.argv[0]) + ' ' + ' '.join(sys.argv[1:])
        if len(name) > 256:
            name = name[:256] + '...'
        framework = Dict()
        framework.user = getuser()
        if framework.user == 'root':
            raise Exception('dpark is not allowed to run as \'root\'')
        framework.name = name
        framework.hostname = socket.gethostname()
        framework.webui_url = self.options.webui_url

        self.driver = MesosSchedulerDriver(self,
                                           framework,
                                           self.master,
                                           use_addict=True)
        self.driver.start()
        logger.debug('Mesos Scheudler driver started')

        self.started = True
        self.last_finish_time = time.time()

        def check():
            while self.started:
                now = time.time()
                if (not self.activeJobs
                        and now - self.last_finish_time > MAX_IDLE_TIME):
                    logger.info('stop mesos scheduler after %d seconds idle',
                                now - self.last_finish_time)
                    self.stop()
                    break
                time.sleep(1)

        spawn(check)

    def start_logger(self, output):
        """Forward lines pushed to a ZeroMQ PULL socket into ``output``.

        Binds on a random TCP port and returns its ``tcp://host:port``
        address; the address ends up in the executor data built by
        ``getExecutorInfo``.
        """
        sock = env.ctx.socket(zmq.PULL)
        port = sock.bind_to_random_port('tcp://0.0.0.0')

        def collect_log():
            # NOTE(review): ``_shutdown`` is not set in this class; presumably
            # provided by DAGScheduler -- confirm.
            while not self._shutdown:
                if sock.poll(1000, zmq.POLLIN):
                    line = sock.recv()
                    output.write(line)

        spawn(collect_log)

        host = socket.gethostname()
        addr = 'tcp://%s:%d' % (host, port)
        logger.debug('log collecter start at %s', addr)
        return addr

    @safe
    def registered(self, driver, frameworkId, masterInfo):
        self.isRegistered = True
        logger.debug('connect to master %s:%s, registered as %s',
                     masterInfo.hostname, masterInfo.port, frameworkId.value)
        self.executor = self.getExecutorInfo(str(frameworkId.value))

    @safe
    def reregistered(self, driver, masterInfo):
        logger.warning('re-connect to mesos master %s:%s', masterInfo.hostname,
                       masterInfo.port)

    @safe
    def disconnected(self, driver):
        logger.debug('framework is disconnected')

    def _get_container_image(self):
        return self.options.image

    @safe
    def getExecutorInfo(self, framework_id):
        """Build the ExecutorInfo used for every task of this framework.

        Includes the executor command, UID/GID environment, an optional
        Docker container (with passwd/group, MooseFS, work-dir and
        ``--volumes`` mounts), EXECUTOR_MEMORY/EXECUTOR_CPUS resources, and
        marshalled start-up data for the executor.

        Raises:
            Exception: a ``--volumes`` entry has more than three ':' fields.
        """
        info = Dict()
        info.framework_id.value = framework_id

        if self.use_self_as_exec:
            info.command.value = os.path.abspath(sys.argv[0])
            info.executor_id.value = sys.argv[0]
        else:
            info.command.value = '%s %s' % (
                sys.executable,
                os.path.abspath(
                    os.path.join(os.path.dirname(__file__), 'executor.py')))
            info.executor_id.value = 'default'

        info.command.environment.variables = variables = []

        v = Dict()
        variables.append(v)
        v.name = 'UID'
        v.value = str(os.getuid())

        v = Dict()
        variables.append(v)
        v.name = 'GID'
        v.value = str(os.getgid())

        container_image = self._get_container_image()
        if container_image:
            info.container.type = 'DOCKER'
            info.container.docker.image = container_image
            info.container.docker.parameters = parameters = []
            p = Dict()
            p.key = 'memory-swap'
            p.value = '-1'
            parameters.append(p)

            info.container.volumes = volumes = []
            for path in ['/etc/passwd', '/etc/group']:
                v = Dict()
                volumes.append(v)
                v.host_path = v.container_path = path
                v.mode = 'RO'

            for path in conf.MOOSEFS_MOUNT_POINTS:
                v = Dict()
                volumes.append(v)
                v.host_path = v.container_path = path
                v.mode = 'RW'

            for path in conf.DPARK_WORK_DIR.split(','):
                v = Dict()
                volumes.append(v)
                v.host_path = v.container_path = path
                v.mode = 'RW'

            def _mount_volume(volumes, host_path, container_path, mode):
                v = Dict()
                volumes.append(v)
                v.container_path = container_path
                v.mode = mode
                if host_path:
                    v.host_path = host_path

            if self.options.volumes:
                # Accepted forms: host:container:mode, host:container,
                # or container alone.
                for volume in self.options.volumes.split(','):
                    fields = volume.split(':')
                    if len(fields) == 3:
                        host_path, container_path, mode = fields
                        mode = mode.upper()
                        assert mode in ('RO', 'RW')
                    elif len(fields) == 2:
                        host_path, container_path = fields
                        mode = 'RW'
                    elif len(fields) == 1:
                        container_path, = fields
                        host_path = ''
                        mode = 'RW'
                    else:
                        # Format the message; passing ``volume`` as a second
                        # Exception argument never interpolated it.
                        raise Exception('cannot parse volume %s' % volume)
                    _mount_volume(volumes, host_path, container_path, mode)

        info.resources = resources = []

        mem = Dict()
        resources.append(mem)
        mem.name = 'mem'
        mem.type = 'SCALAR'
        mem.scalar.value = EXECUTOR_MEMORY

        cpus = Dict()
        resources.append(cpus)
        cpus.name = 'cpus'
        cpus.type = 'SCALAR'
        cpus.scalar.value = EXECUTOR_CPUS

        Script = os.path.realpath(sys.argv[0])
        info.name = Script

        info.data = encode_data(
            marshal.dumps((Script, os.getcwd(), sys.path, dict(os.environ),
                           self.task_per_node, self.out_logger,
                           self.err_logger, self.logLevel, env.environ)))
        return info

    @safe
    def submitTasks(self, tasks):
        """Register ``tasks`` as a new SimpleJob and make sure Mesos is running.

        Starts the driver on first use and waits until it is registered;
        otherwise asks Mesos for more offers.
        """
        if not tasks:
            return

        job = SimpleJob(self, tasks, self.cpus, tasks[0].rdd.mem or self.mem)
        self.activeJobs[job.id] = job
        self.activeJobsQueue.append(job)
        self.jobTasks[job.id] = set()
        stage_scope = ''
        try:
            # Best-effort: the web UI is optional and only used for logging.
            from dpark.web.ui.views.rddopgraph import StageInfo
            stage_scope = StageInfo.idToRDDNode[
                tasks[0].rdd.id].scope.call_site
        except Exception:
            pass
        stage = self.idToStage[tasks[0].stageId]
        stage.try_times += 1
        logger.info(
            'Got job %d with %d tasks for stage: %d(try %d times) '
            'at scope[%s] and rdd:%s', job.id, len(tasks), tasks[0].stageId,
            stage.try_times, stage_scope, tasks[0].rdd)

        need_revive = self.started
        if not self.started:
            self.start_driver()
        # Drop the lock while waiting so the registered() callback can run.
        while not self.isRegistered:
            self.lock.release()
            time.sleep(0.01)
            self.lock.acquire()

        if need_revive:
            self.requestMoreResources()

    def requestMoreResources(self):
        logger.debug('reviveOffers')
        self.driver.reviveOffers()

    @safe
    def resourceOffers(self, driver, offers):
        """Dispatch tasks of active jobs onto the offered resources.

        With no active jobs, offers are suppressed and declined for five
        minutes. Otherwise jobs are served in submission order, repeatedly
        sweeping the shuffled offers until a sweep launches nothing.
        """
        rf = Dict()
        if not self.activeJobs:
            driver.suppressOffers()
            rf.refuse_seconds = 60 * 5
            for o in offers:
                driver.declineOffer(o.id, rf)
            return

        start = time.time()
        random.shuffle(offers)
        cpus = [self.getResource(o.resources, 'cpus') for o in offers]
        # Agents without any of our tasks still need room for the executor.
        mems = [
            self.getResource(o.resources, 'mem') -
            (o.agent_id.value not in self.agentTasks and EXECUTOR_MEMORY or 0)
            for o in offers
        ]
        logger.debug('get %d offers (%s cpus, %s mem), %d jobs', len(offers),
                     sum(cpus), sum(mems), len(self.activeJobs))

        tasks = {}
        for job in self.activeJobsQueue:
            while True:
                launchedTask = False
                for i, o in enumerate(offers):
                    sid = o.agent_id.value
                    group = (self.getAttribute(o.attributes, 'group')
                             or 'None')
                    if (self.group or
                            group.startswith('_')) and group not in self.group:
                        continue
                    if self.agentTasks.get(sid, 0) >= self.task_per_node:
                        continue
                    if (mems[i] < self.mem + EXECUTOR_MEMORY
                            or cpus[i] < self.cpus + EXECUTOR_CPUS):
                        continue
                    t = job.slaveOffer(str(o.hostname), cpus[i], mems[i])
                    if not t:
                        continue
                    task = self.createTask(o, job, t)
                    tasks.setdefault(o.id.value, []).append(task)

                    logger.debug('dispatch %s into %s', t, o.hostname)
                    tid = task.task_id.value
                    self.jobTasks[job.id].add(tid)
                    self.taskIdToJobId[tid] = job.id
                    self.taskIdToAgentId[tid] = sid
                    self.agentTasks[sid] = self.agentTasks.get(sid, 0) + 1
                    cpus[i] -= min(cpus[i], t.cpus)
                    mems[i] -= t.mem
                    launchedTask = True

                if not launchedTask:
                    break

        used = time.time() - start
        if used > 10:
            logger.error('use too much time in resourceOffers: %.2fs', used)

        for o in offers:
            if o.id.value in tasks:
                driver.launchTasks(o.id, tasks[o.id.value])
            else:
                driver.declineOffer(o.id)

        logger.debug('reply with %d tasks, %s cpus %s mem left',
                     sum(len(ts) for ts in tasks.values()), sum(cpus),
                     sum(mems))

    @safe
    def offerRescinded(self, driver, offer_id):
        logger.debug('rescinded offer: %s', offer_id)
        if self.activeJobs:
            self.requestMoreResources()

    def getResource(self, res, name):
        """Return the scalar value of resource ``name``, or 0.0 if absent."""
        for r in res:
            if r.name == name:
                return r.scalar.value
        return 0.0

    def getAttribute(self, attrs, name):
        """Return the text value of attribute ``name``, or None if absent."""
        for r in attrs:
            if r.name == name:
                return r.text.value

    def createTask(self, o, job, t):
        """Build a TaskInfo with id ``job:task:tried`` for offer ``o``."""
        task = Dict()
        tid = '%s:%s:%s' % (job.id, t.id, t.tried)
        task.name = 'task %s' % tid
        task.task_id.value = tid
        task.agent_id.value = o.agent_id.value
        task.data = encode_data(
            compress(six.moves.cPickle.dumps((t, t.tried), -1)))
        task.executor = self.executor
        if len(task.data) > 1000 * 1024:
            logger.warning('task too large: %s %d', t, len(task.data))

        resources = task.resources = []

        cpu = Dict()
        resources.append(cpu)
        cpu.name = 'cpus'
        cpu.type = 'SCALAR'
        cpu.scalar.value = t.cpus

        mem = Dict()
        resources.append(mem)
        mem.name = 'mem'
        mem.type = 'SCALAR'
        mem.scalar.value = t.mem
        return task

    @safe
    def statusUpdate(self, driver, status):
        """Route a Mesos status update to the owning job.

        TASK_RUNNING is forwarded (or the task is killed if its job is gone).
        Any other state releases the task's bookkeeping; for TASK_FINISHED /
        TASK_FAILED with data, the result is decoded (possibly fetched from a
        URL when ``flag >= 2``) and passed to ``job.statusUpdate``.
        """
        tid = status.task_id.value
        state = status.state
        logger.debug('status update: %s %s', tid, state)

        jid = self.taskIdToJobId.get(tid)
        _, task_id, tried = list(map(int, tid.split(':')))
        if state == 'TASK_RUNNING':
            if jid in self.activeJobs:
                job = self.activeJobs[jid]
                job.statusUpdate(task_id, tried, state)
            else:
                logger.debug('kill task %s as its job has gone', tid)
                self.driver.killTask(Dict(value=tid))

            return

        self.taskIdToJobId.pop(tid, None)
        if jid in self.jobTasks:
            # discard: a repeated terminal update for the same task must not
            # raise KeyError.
            self.jobTasks[jid].discard(tid)
        if tid in self.taskIdToAgentId:
            agent_id = self.taskIdToAgentId[tid]
            if agent_id in self.agentTasks:
                self.agentTasks[agent_id] -= 1
            del self.taskIdToAgentId[tid]

        if jid not in self.activeJobs:
            logger.debug('ignore task %s as its job has gone', tid)
            return

        job = self.activeJobs[jid]
        reason = status.get('message')
        data = status.get('data')
        if state in ('TASK_FINISHED', 'TASK_FAILED') and data:
            try:
                reason, result, accUpdate = six.moves.cPickle.loads(
                    decode_data(data))
                if result:
                    flag, data = result
                    if flag >= 2:
                        try:
                            data = urllib.request.urlopen(data).read()
                        except IOError:
                            # try again
                            data = urllib.request.urlopen(data).read()
                        flag -= 2
                    data = decompress(data)
                    if flag == 0:
                        result = marshal.loads(data)
                    else:
                        result = six.moves.cPickle.loads(data)
            except Exception as e:
                logger.warning('error when cPickle.loads(): %s, data:%s', e,
                               len(data))
                return job.statusUpdate(task_id, tried, 'TASK_FAILED',
                                        'load failed: %s' % e)
            else:
                return job.statusUpdate(task_id, tried, state, reason, result,
                                        accUpdate)

        # killed, lost, load failed
        job.statusUpdate(task_id, tried, state, reason or data)

    def jobFinished(self, job):
        """Retire ``job`` and kill its leftover and orphaned tasks."""
        logger.debug('job %s finished', job.id)
        if job.id in self.activeJobs:
            self.last_finish_time = time.time()
            del self.activeJobs[job.id]
            self.activeJobsQueue.remove(job)
            for tid in self.jobTasks[job.id]:
                self.driver.killTask(Dict(value=tid))
            del self.jobTasks[job.id]

            if not self.activeJobs:
                self.agentTasks.clear()

        for tid, jid in six.iteritems(self.taskIdToJobId):
            if jid not in self.activeJobs:
                logger.debug('kill task %s, because it is orphan', tid)
                self.driver.killTask(Dict(value=tid))

    @safe
    def check(self):
        for job in self.activeJobs.values():
            if job.check_task_timeout():
                self.requestMoreResources()

    @safe
    def error(self, driver, message):
        logger.error('Mesos error message: %s', message)
        raise RuntimeError(message)

    # @safe
    def stop(self):
        """Stop the driver (without failover) and mark the scheduler idle."""
        if not self.started:
            return
        logger.debug('stop scheduler')
        self.started = False
        self.isRegistered = False
        self.driver.stop(False)
        self.driver = None

    def defaultParallelism(self):
        return 16

    def frameworkMessage(self, driver, executor_id, agent_id, data):
        logger.warning('[agent %s] %s', agent_id.value, data)

    def executorLost(self, driver, executor_id, agent_id, status):
        logger.warning('executor at %s %s lost: %s', agent_id.value,
                       executor_id.value, status)
        self.agentTasks.pop(agent_id.value, None)

    def slaveLost(self, driver, agent_id):
        logger.warning('agent %s lost', agent_id.value)
        self.agentTasks.pop(agent_id.value, None)

    def killTask(self, job_id, task_id, tried):
        """Ask Mesos to kill the task with id ``job_id:task_id:tried``."""
        tid = Dict()
        tid.value = '%s:%s:%s' % (job_id, task_id, tried)
        self.driver.killTask(tid)
コード例 #15
0
class MesosBatchSystem(BatchSystemLocalSupport,
                       AbstractScalableBatchSystem,
                       Scheduler):
    """
    A Toil batch system implementation that uses Apache Mesos to distribute toil jobs as Mesos
    tasks over a cluster of agent nodes. A Mesos framework consists of a scheduler and an
    executor. This class acts as the scheduler and is typically run on the master node that also
    runs the Mesos master process with which the scheduler communicates via a driver component.
    The executor is implemented in a separate class. It is run on each agent node and
    communicates with the Mesos agent process via another driver object. The scheduler may also
    be run on a separate node from the master, which we then call somewhat ambiguously the driver
    node.
    """

    @classmethod
    def supportsAutoDeployment(cls):
        """Return True: this batch system ships the auto-deployed user script with its tasks."""
        supported = True
        return supported

    @classmethod
    def supportsWorkerCleanup(cls):
        """Return True: worker cleanup info is passed along with every job."""
        supported = True
        return supported

    class ExecutorInfo:
        """
        What the scheduler knows about its executor on one agent node (see _registerNode()).
        """

        def __init__(self, nodeAddress, agentId, nodeInfo, lastSeen):
            super().__init__()
            # Address of the agent's node; self.executors is keyed by it.
            self.nodeAddress = nodeAddress
            # Mesos agent ID of the agent on that node.
            self.agentId = agentId
            # NodeInfo for the node, or None while it is not yet known.
            self.nodeInfo = nodeInfo
            # Timestamp (as from time.time()) of the most recent contact with the node.
            self.lastSeen = lastSeen

    def __init__(self, config, maxCores, maxMemory, maxDisk):
        """
        Set up the scheduler-side bookkeeping, describe the executor and start the Mesos driver.

        The driver is started last because it calls back into this object, so all state below
        must already exist.
        """
        super().__init__(config, maxCores, maxMemory, maxDisk)

        # --- Job submission ---
        # The auto-deployed resource representing the user script (a toil.resource.Resource).
        # It is passed along in every Mesos task; see setUserScript().
        self.userScript = None
        # Queued jobs, grouped by job type (i.e. by resource requirements).
        self.jobQueues = JobQueue()
        # Mesos has no easy way of getting a task's resources, so they are tracked here from
        # issueBatchJob() until the task is launched.
        self.taskResources = {}

        # --- Connection to Mesos ---
        # Address of the Mesos master as host:port; host can be an IP or a hostname.
        self.mesos_endpoint = config.mesos_endpoint
        # The Mesos driver used by this scheduler (created by _startDriver()).
        self.driver = None
        # The string framework ID assigned when registering with the Mesos master.
        self.frameworkId = None

        # --- Running and finished jobs ---
        # Launched job IDs -> TaskData objects.
        self.runningJobMap = {}
        # Jobs whose status has been updated, according to Mesos.
        self.updatedJobsQueue = Queue()
        # Host address -> IDs of jobs launched there. Somewhat redundant, since Mesos returns
        # the number of workers per node, but that information isn't guaranteed to reach the
        # leader. When Mesos does return it, prefer it over this tracking.
        self.hostToJobIDs = {}

        # --- Killing jobs ---
        # IDs of jobs to be killed.
        self.killJobIds = set()
        # Written to when Mesos kills tasks, as directed by Toil. Jobs must not enter this set
        # until they are removed from runningJobMap.
        self.killedJobIds = set()
        # Jobs on which killBatchJobs was called, whether they were actually killed or ended
        # by themselves.
        self.intendedKill = set()

        # --- Node tracking ---
        # see self.setNodeFilter
        self.nodeFilter = []
        # Node IP -> ExecutorInfo describing our executor on that node (an approximation).
        self.executors = {}
        # Agent ID -> last observed IP address of its node.
        self.agentsByID = {}
        # Mesos agent IDs of agents on non-preemptable nodes. Only approximate: recently
        # launched nodes may be missing for a while, and absence does not imply
        # preemptability. Tracking this set (rather than the preemptable one) is the safer
        # choice: non-preemptability is the stronger requirement, so a recently launched
        # preemptable node can never be mistaken for a non-preemptable one.
        self.nonPreemptableNodes = set()
        # Nodes whose offers are declined; see ignoreNode()/unignoreNode().
        self.ignoredNodes = set()

        # How often (seconds) resourceOffers() may log that no queued job fits any offer.
        # That happens when the cluster is busy or its nodes are too small for the jobs.
        self.lastTimeOfferLogged = 0
        self.logPeriod = 30  # seconds

        self.executor = self._buildExecutor()

        # Must stay last: the driver calls back into this object.
        self._startDriver()

    def setUserScript(self, userScript):
        """Remember the auto-deployed user script; issueBatchJob() attaches it to every job."""
        self.userScript = userScript

    def ignoreNode(self, nodeAddress):
        """Add ``nodeAddress`` to the set of nodes whose offers resourceOffers() declines."""
        ignored = self.ignoredNodes
        ignored.add(nodeAddress)

    def unignoreNode(self, nodeAddress):
        """
        Take ``nodeAddress`` out of the ignored set so its offers are used again.

        :raises KeyError: if the node was not being ignored.
        """
        ignored = self.ignoredNodes
        ignored.remove(nodeAddress)

    def issueBatchJob(self, jobNode: JobDescription, job_environment: Optional[Dict[str, str]] = None):
        """
        Queue a job to run on Mesos and return its batch system job ID.

        If handleLocalJob() takes the job (returns an ID rather than None), that ID is returned
        and Mesos is not involved. Otherwise the job's memory/cores/disk request is validated
        with checkResourceRequest(), and the job is wrapped in a ToilJob (command, user script,
        environment, worker cleanup info). It is then queued under its resource shape until
        resourceOffers() finds an offer it fits into.

        :param jobNode: the job to run; its command, memory, cores, disk and preemptable
            fields are used (memory and disk in bytes).
        :param job_environment: extra environment variables for this job only, applied on top
            of (and overriding) the batch system's own environment.
        :return: the job ID assigned to the job.
        """
        localID = self.handleLocalJob(jobNode)
        # Compare with None explicitly: a valid local job ID of 0 is falsy and must not be
        # mistaken for "not a local job" (that would run the job a second time on Mesos).
        if localID is not None:
            return localID

        mesos_resources = {
            "memory": jobNode.memory,
            "cores": jobNode.cores,
            "disk": jobNode.disk,
            "preemptable": jobNode.preemptable
        }
        self.checkResourceRequest(
            memory=mesos_resources["memory"],
            cores=mesos_resources["cores"],
            disk=mesos_resources["disk"]
        )

        jobID = self.getNextJobID()
        environment = self.environment.copy()
        if job_environment:
            environment.update(job_environment)

        job = ToilJob(jobID=jobID,
                      name=str(jobNode),
                      resources=MesosShape(wallTime=0, **mesos_resources),
                      command=jobNode.command,
                      userScript=self.userScript,
                      environment=environment,
                      workerCleanupInfo=self.workerCleanupInfo)
        jobType = job.resources
        log.debug("Queueing the job command: %s with job id: %s ...", jobNode.command, str(jobID))

        # TODO: round all elements of resources

        # Record the resources BEFORE queueing: once the job is in jobQueues, resourceOffers()
        # may launch it, and _updateStateToRunning() then reads taskResources[jobID].
        self.taskResources[jobID] = job.resources
        self.jobQueues.insertJob(job, jobType)
        log.debug("... queued")
        return jobID

    def killBatchJobs(self, jobIDs):
        """
        Kill the given jobs and block until every issued one is confirmed dead.

        Local jobs are killed first via killLocalJobs(). For each remaining ID:
        - The ID is queued in killJobIds.
        - It is recorded in intendedKill, so getUpdatedBatchJob() discards its exit update.
        - If currently issued, it gets a Mesos killTask; otherwise it is taken off
          killJobIds again and skipped.
        The method then polls once per second, with no timeout, until statusUpdate() has put
        every killed ID into killedJobIds.
        """
        self.killLocalJobs(jobIDs)

        # The driver thread does the actual killing; we only hand it instructions.
        assert self.driver is not None

        # IDs this call asked Mesos to kill that are not confirmed dead yet.
        awaiting = set()

        for jobID in jobIDs:
            self.killJobIds.add(jobID)
            awaiting.add(jobID)
            # Remember the intent in case the job finishes up by itself.
            self.intendedKill.add(jobID)

            if jobID not in self.getIssuedBatchJobIDs():
                # Never issued (maybe a local job): nothing to kill.
                log.debug("Skip non-issued job %s" % str(jobID))
                self.killJobIds.remove(jobID)
                awaiting.remove(jobID)
                continue

            taskId = addict.Dict()
            taskId.value = str(jobID)
            log.debug("Kill issued job %s" % str(jobID))
            self.driver.killTask(taskId)

        while awaiting:
            confirmed = awaiting & self.killedJobIds
            if not confirmed:
                time.sleep(1)
                continue
            awaiting -= confirmed
            # Clear our confirmations out of the set the other thread writes them into.
            self.killedJobIds -= confirmed
        # statusUpdate() drops a job from runningJobMap before adding it to killedJobIds, so
        # none of the jobs we asked to kill is still tracked as running at this point.

    def getIssuedBatchJobIDs(self):
        """
        Return, as a list, the IDs of all issued jobs: queued Mesos jobs, running Mesos jobs
        and the local jobs.
        """
        remoteIDs = set(self.jobQueues.jobIDs())
        # Snapshot the keys; runningJobMap is also mutated by driver callbacks.
        remoteIDs.update(list(self.runningJobMap.keys()))
        return list(remoteIDs) + list(self.getIssuedLocalJobIDs())

    def getRunningBatchJobIDs(self):
        """
        Map each running job ID to the seconds elapsed since it started.

        Mesos jobs are measured from their TaskData start time; running local jobs are merged
        in from getRunningLocalJobIDs().
        """
        # Snapshot the items first; runningJobMap is also mutated by driver callbacks.
        runtimes = {jobID: time.time() - taskData.startTime
                    for jobID, taskData in list(self.runningJobMap.items())}
        runtimes.update(self.getRunningLocalJobIDs())
        return runtimes

    def getUpdatedBatchJob(self, maxWait):
        """
        Return the next job-ended update, or None if none arrives within ``maxWait`` seconds.

        Local jobs are checked first without waiting. Updates for jobs that killBatchJobs() was
        asked to kill are consumed and discarded. The overall wait is bounded by ``maxWait``
        (``None`` waits indefinitely), even when such discarded updates arrive in between.
        """
        local_tuple = self.getUpdatedLocalJob(0)
        if local_tuple:
            return local_tuple
        # Track one overall deadline: re-using maxWait as the timeout of every get() would
        # restart the full wait after each discarded update for a killed job.
        deadline = None if maxWait is None else time.time() + maxWait
        while True:
            timeout = None if deadline is None else max(0, deadline - time.time())
            try:
                item = self.updatedJobsQueue.get(timeout=timeout)
            except Empty:
                return None
            try:
                self.intendedKill.remove(item.jobID)
            except KeyError:
                log.debug('Job %s ended with status %i, took %s seconds.', item.jobID, item.exitStatus,
                          '???' if item.wallTime is None else str(item.wallTime))
                return item
            else:
                log.debug('Job %s ended naturally before it could be killed.', item.jobID)

    def nodeInUse(self, nodeIP: str) -> bool:
        """
        Return True if any job launched by this scheduler is still tracked on node ``nodeIP``.

        When a job ends, statusUpdate() only removes its ID from the host's list in
        hostToJobIDs and never deletes the host key. Key membership alone would therefore
        report a node as busy forever after its first job, so require a non-empty list.
        """
        return bool(self.hostToJobIDs.get(nodeIP))

    def getWaitDuration(self):
        """
        Return the period (in seconds) to wait between checks for missing or overlong jobs;
        always one second here.
        """
        waitSeconds = 1
        return waitSeconds

    def _buildExecutor(self):
        """
        Return an addict ExecutorInfo-shaped object describing our executor.

        The command is the ``_toil_mesos_executor`` entry point, the executor ID embeds this
        process's PID, and the source is the current user's login name.
        """
        # The executor program is installed as a setuptools entry point by setup.py
        command = resolveEntryPoint('_toil_mesos_executor')
        executorId = "toil-%i" % os.getpid()
        owner = pwd.getpwuid(os.getuid()).pw_name
        return addict.Dict(name="toil",
                           command={'value': command},
                           executor_id={'value': executorId},
                           source=owner)

    def _startDriver(self):
        """
        Create and start the pymesos scheduler driver for the "toil" framework.

        The driver talks to the resolved Mesos master endpoint and calls back into this object,
        handing it addict objects and acknowledging status updates implicitly.
        """
        # pymesos leaves determining the user name to us.
        user = getpass.getuser()
        framework = addict.Dict(user=user, name="toil")
        framework.principal = framework.name
        master = self._resolveAddress(self.mesos_endpoint)
        self.driver = MesosSchedulerDriver(self, framework, master,
                                           use_addict=True, implicit_acknowledgements=True)
        self.driver.start()

    @staticmethod
    def _resolveAddress(address):
        """
        Resolves the host in the given string. The input is of the form host[:port]. This method
        is idempotent, i.e. the host may already be a dotted IP address.

        >>> # noinspection PyProtectedMember
        >>> f=MesosBatchSystem._resolveAddress
        >>> f('localhost')
        '127.0.0.1'
        >>> f('127.0.0.1')
        '127.0.0.1'
        >>> f('localhost:123')
        '127.0.0.1:123'
        >>> f('127.0.0.1:123')
        '127.0.0.1:123'
        """
        host, *port = address.split(':')
        # At most one ':' (separating host and port) is allowed.
        assert len(port) <= 1
        return ':'.join([socket.gethostbyname(host)] + port)

    def shutdown(self) -> None:
        """
        Shut down local jobs, then stop the Mesos driver and wait for it to finish.

        :raises RuntimeError: if join() reports a status other than None or 'DRIVER_STOPPED'.
        """
        self.shutdownLocal()
        log.debug("Stopping Mesos driver")
        self.driver.stop()
        log.debug("Joining Mesos driver")
        status = self.driver.join()
        log.debug("Joined Mesos driver")
        # TODO: The docs say join should return a code, but it keeps returning None when
        # apparently successful. So None is tolerated as success too.
        if status in (None, 'DRIVER_STOPPED'):
            return
        raise RuntimeError("Mesos driver failed with %s" % status)

    def registered(self, driver, frameworkId, masterInfo):
        """
        Invoked when the scheduler successfully registers with a Mesos master; stores the
        assigned framework ID string in self.frameworkId.
        """
        assignedId = frameworkId.value
        log.debug("Registered with framework ID %s", assignedId)
        self.frameworkId = assignedId

    def _declineAllOffers(self, driver, offers):
        """Decline every offer in ``offers`` through ``driver``."""
        for offerId in (offer.id for offer in offers):
            driver.declineOffer(offerId)

    def _parseOffer(self, offer):
        """
        Return ``(cores, memory, disk, preemptable)`` for an offer.

        The cpus, mem and disk scalar resources are summed (memory and disk in the MiB units
        Mesos uses). The 'preemptable' attribute is parsed with strict_bool and defaults to
        False when the agent does not set it.
        """
        preemptable = None
        for attribute in offer.attributes:
            if attribute.name != 'preemptable':
                continue
            assert preemptable is None, "Attribute 'preemptable' occurs more than once."
            preemptable = strict_bool(attribute.text.value)
        if preemptable is None:
            log.debug('Agent not marked as either preemptable or not. Assuming non-preemptable.')
            preemptable = False

        cores = memory = disk = 0
        for resource in offer.resources:
            name = resource.name
            if name == "cpus":
                cores += resource.scalar.value
            elif name == "mem":
                memory += resource.scalar.value
            elif name == "disk":
                disk += resource.scalar.value
        return cores, memory, disk, preemptable

    def _prepareToRun(self, jobType, offer):
        """Pop the next queued job of ``jobType`` and return a Mesos task for it on ``offer``."""
        # Taking the first element of the type's queue keeps launches FIFO.
        return self._newMesosTask(self.jobQueues.nextJobOfType(jobType), offer)

    def _updateStateToRunning(self, offer, runnableTasks):
        """
        Record the tasks just launched with ``offer`` as running.

        For each task:
        - its job ID is appended to the agent's entry in hostToJobIDs;
        - a TaskData built from the resources recorded at issue time goes into runningJobMap;
        - the taskResources entry is then removed.
        """
        if not runnableTasks:
            return
        # Every task of one offer runs on the same agent, so resolve its hostname once rather
        # than once per task. Each lookup may hit DNS, and repeated lookups could even file
        # the tasks of one offer under different IPs.
        agentIP = socket.gethostbyname(offer.hostname)
        for task in runnableTasks:
            jobID = int(task.task_id.value)
            resources = self.taskResources[jobID]
            self.hostToJobIDs.setdefault(agentIP, []).append(jobID)
            self.runningJobMap[jobID] = TaskData(startTime=time.time(),
                                                 agentID=offer.agent_id.value,
                                                 agentIP=agentIP,
                                                 executorID=task.executor.executor_id.value,
                                                 cores=resources.cores,
                                                 memory=resources.memory)
            del self.taskResources[jobID]
            log.debug('Launched Mesos task %s.', task.task_id.value)

    def resourceOffers(self, driver, offers):
        """
        Invoked when resources have been offered to this framework.

        Each offer is either used to launch the queued jobs that fit into it (in a single
        launchTasks() call) or declined. Offers from ignored nodes are always declined. When
        nothing is queued, every offer is declined immediately. If no offer could be used, a
        message saying so is logged at most once per ``self.logPeriod`` seconds.
        """
        self._trackOfferedNodes(offers)

        # Right now, gives priority to largest jobs
        jobTypes = self.jobQueues.sortedTypes

        if not jobTypes:
            # Without jobs, we can get stuck with no jobs and no new offers until we decline it.
            self._declineAllOffers(driver, offers)
            return

        unableToRun = True
        for offer in offers:
            # NOTE(review): ignoreNode() takes a node *address*, but this compares against
            # offer.hostname; it only matches if agents advertise that address as hostname.
            if offer.hostname in self.ignoredNodes:
                driver.declineOffer(offer.id)
                continue
            runnableTasks = self._tasksForOffer(offer, jobTypes)
            # Launch all runnable tasks together so we only call launchTasks once per offer
            if runnableTasks:
                unableToRun = False
                driver.launchTasks(offer.id, runnableTasks)
                self._updateStateToRunning(offer, runnableTasks)
            else:
                log.debug('Although there are queued jobs, none of them could be run with offer %s '
                          'extended to the framework.', offer.id.value)
                driver.declineOffer(offer.id)

        if unableToRun and time.time() > (self.lastTimeOfferLogged + self.logPeriod):
            self.lastTimeOfferLogged = time.time()
            log.debug('Although there are queued jobs, none of them were able to run in '
                      'any of the offers extended to the framework. There are currently '
                      '%i jobs running. Enable debug level logging to see more details about '
                      'job types and offers received.', len(self.runningJobMap))

    def _tasksForOffer(self, offer, jobTypes):
        """
        Pop the queued jobs that fit into ``offer`` and return the Mesos tasks built for them.

        Job types are visited in the given order. Jobs of a type are taken while its queue is
        non-empty and the offer's remaining cores, memory and disk still cover one more job.
        A preemptable offer only accepts preemptable job types.
        """
        # TODO: In an offer, can there ever be more than one resource with the same name?
        offerCores, offerMemory, offerDisk, offerPreemptable = self._parseOffer(offer)
        log.debug('Got offer %s for a %spreemptable agent with %.2f MiB memory, %.2f core(s) '
                  'and %.2f MiB of disk.', offer.id.value, '' if offerPreemptable else 'non-',
                  offerMemory, offerCores, offerDisk)
        remainingCores = offerCores
        remainingMemory = offerMemory
        remainingDisk = offerDisk

        runnableTasks = []
        for jobType in jobTypes:
            # Toil specifies disk and memory in bytes but Mesos uses MiB.
            # NOTE(review): _newMesosTask rounds sub-MiB requests up to 1 MiB, but this
            # accounting subtracts the unrounded amount.
            while (not self.jobQueues.typeEmpty(jobType)
                   # On a non-preemptable node we can run any job, on a preemptable node we
                   # can only run preemptable jobs:
                   and (not offerPreemptable or jobType.preemptable)
                   and remainingCores >= jobType.cores
                   and remainingDisk >= b_to_mib(jobType.disk)
                   and remainingMemory >= b_to_mib(jobType.memory)):
                task = self._prepareToRun(jobType, offer)
                # TODO: this used to be a conditional but Hannes wanted it changed to an assert
                # TODO: ... so we can understand why it exists.
                assert int(task.task_id.value) not in self.runningJobMap
                runnableTasks.append(task)
                log.debug("Preparing to launch Mesos task %s with %.2f cores, %.2f MiB memory, and %.2f MiB disk using offer %s ...",
                          task.task_id.value, jobType.cores, b_to_mib(jobType.memory), b_to_mib(jobType.disk), offer.id.value)
                remainingCores -= jobType.cores
                remainingMemory -= b_to_mib(jobType.memory)
                remainingDisk -= b_to_mib(jobType.disk)
            if not self.jobQueues.typeEmpty(jobType):
                # Report that the remaining jobs of this type cannot run with the current resources.
                log.debug('Offer %(offer)s not suitable to run the tasks with requirements '
                          '%(requirements)r. Mesos offered %(memory)s memory, %(cores)s cores '
                          'and %(disk)s of disk on a %(non)spreemptable agent.',
                          dict(offer=offer.id.value,
                               requirements=jobType.__dict__,
                               non='' if offerPreemptable else 'non-',
                               memory=mib_to_b(offerMemory),
                               cores=offerCores,
                               disk=mib_to_b(offerDisk)))
        return runnableTasks

    def _trackOfferedNodes(self, offers):
        """
        Register the node behind each offer and refresh its preemptability.

        The offer's hostname is resolved to an IP and registered via _registerNode(). The
        agent ID is then removed from nonPreemptableNodes if the last 'preemptable' attribute
        of the offer is true, and added to it otherwise.
        """
        for offer in offers:
            agentId = offer.agent_id
            # All AgentID messages are required to have a value according to the Mesos Protobuf file.
            assert 'value' in agentId
            try:
                nodeAddress = socket.gethostbyname(offer.hostname)
            except BaseException:
                log.debug("Failed to resolve hostname %s" % offer.hostname)
                raise
            self._registerNode(nodeAddress, agentId.value)

            preemptable = False
            for attribute in offer.attributes:
                if attribute.name == 'preemptable':
                    preemptable = strict_bool(attribute.text.value)

            if preemptable:
                self.nonPreemptableNodes.discard(agentId.value)
            else:
                self.nonPreemptableNodes.add(agentId.value)

    def _filterOfferedNodes(self, offers):
        """
        Return the offers whose node passes the first node filter (see setNodeFilter).

        All offers are returned when no filter is set. When filtering, offers from nodes
        without a known executor are dropped.
        """
        if not self.nodeFilter:
            return offers
        candidates = []
        for offer in offers:
            info = self.executors.get(socket.gethostbyname(offer.hostname))
            if info:
                candidates.append(info)
        accept = self.nodeFilter[0]
        allowedIPs = {info.nodeAddress for info in candidates if accept(info)}
        return [offer for offer in offers if socket.gethostbyname(offer.hostname) in allowedIPs]

    def _newMesosTask(self, job, offer):
        """
        Build the Mesos task object for the given Toil job and Mesos offer.

        The task carries the pickled, encoded job as its data and a copy of our executor info.
        It requests the job's cores plus its disk and memory converted to MiB; disk and memory
        below 1 MiB are rounded up to 1 MiB with a warning.
        """
        task = addict.Dict()
        task.task_id.value = str(job.jobID)
        task.agent_id.value = offer.agent_id.value
        task.name = job.name
        task.data = encode_data(pickle.dumps(job))
        task.executor = addict.Dict(self.executor)

        def scalarResource(name, value):
            """Return one Mesos SCALAR resource entry."""
            resource = addict.Dict()
            resource.name = name
            resource.type = 'SCALAR'
            resource.scalar.value = value
            return resource

        def mibAtLeastOne(what, numBytes):
            """Convert bytes to MiB, rounding anything below 1 MiB up to 1 MiB with a warning."""
            mib = b_to_mib(numBytes)
            # >= rather than >: exactly 1 MiB is acceptable and must not trigger the warning.
            if mib >= 1:
                return mib
            log.warning("Job %s uses less %s than Mesos requires. Rounding %s up to 1 MiB.",
                        job.jobID, what, numBytes)
            return 1

        task.resources = [
            scalarResource('cpus', job.resources.cores),
            scalarResource('disk', mibAtLeastOne('disk', job.resources.disk)),
            scalarResource('mem', mibAtLeastOne('memory', job.resources.memory)),
        ]
        return task

    def statusUpdate(self, driver, update):
        """
        Invoked when the status of a task has changed (e.g., a agent is lost and so the task is
        lost, a task finishes and an executor sends a status update saying so, etc). Note that
        returning from this callback _acknowledges_ receipt of this status update! If for
        whatever reason the scheduler aborts during this callback (or the process exits) another
        status update will be delivered (note, however, that this is currently not true if the
        agent sending the status update is lost/fails during that time).

        Terminal states (finished, failed, lost, killed, error) put an UpdatedBatchJobInfo on
        updatedJobsQueue and clear the job's running-state bookkeeping. If the job was being
        killed, its ID is moved from killJobIds to killedJobIds.
        """
        jobID = int(update.task_id.value)
        log.debug("Job %i is in state '%s' due to reason '%s'.", jobID, update.state, update.reason)

        def jobEnded(_exitStatus, wallTime=None, exitReason=None):
            """
            Notify external observers of the job ending.
            """
            self.updatedJobsQueue.put(UpdatedBatchJobInfo(jobID=jobID, exitStatus=_exitStatus, wallTime=wallTime, exitReason=exitReason))
            agentIP = None
            try:
                agentIP = self.runningJobMap[jobID].agentIP
            except KeyError:
                log.warning("Job %i returned exit code %i but isn't tracked as running.",
                            jobID, _exitStatus)
            else:
                # Mark the job as no longer running. We MUST do this BEFORE
                # saying we killed the job, or it will be possible for another
                # thread to kill a job and then see it as running.
                del self.runningJobMap[jobID]

            try:
                self.hostToJobIDs[agentIP].remove(jobID)
            except KeyError:
                log.warning("Job %i returned exit code %i from unknown host.",
                            jobID, _exitStatus)

            try:
                self.killJobIds.remove(jobID)
            except KeyError:
                pass
            else:
                # We were asked to kill this job, so say that we have done so.
                # We do this LAST, after all status updates for the job have
                # been handled, to ensure a consistent view of the scheduler
                # state from other threads.
                self.killedJobIds.add(jobID)

        if update.state == 'TASK_FINISHED':
            # The running time of the job comes from the 'wallTime' label, which is in
            # job-local time in seconds.
            wallTime = None
            for label in update.labels.labels:
                if label['key'] == 'wallTime':
                    wallTime = float(label['value'])
                    break
            if wallTime is None:
                # Don't abort the driver callback over missing metadata; consumers of
                # updatedJobsQueue already handle an unknown (None) wall time.
                log.warning("Job %i finished without reporting its wall time.", jobID)
            jobEnded(0, wallTime=wallTime, exitReason=BatchJobExitReason.FINISHED)
        elif update.state == 'TASK_FAILED':
            try:
                exitStatus = int(update.message)
            except (TypeError, ValueError):
                # The driver hands us addict objects: an absent message reads back as an empty
                # Dict (TypeError), and a non-numeric message raises ValueError.
                exitStatus = EXIT_STATUS_UNAVAILABLE_VALUE
                log.warning("Job %i failed with message '%s' due to reason '%s' on executor '%s' on agent '%s'.",
                            jobID, update.message, update.reason,
                            update.executor_id, update.agent_id)
            else:
                log.warning("Job %i failed with exit status %i and message '%s' due to reason '%s' on executor '%s' on agent '%s'.",
                            jobID, exitStatus,
                            update.message, update.reason,
                            update.executor_id, update.agent_id)

            jobEnded(exitStatus, exitReason=BatchJobExitReason.FAILED)
        elif update.state == 'TASK_LOST':
            log.warning("Job %i is lost.", jobID)
            jobEnded(EXIT_STATUS_UNAVAILABLE_VALUE, exitReason=BatchJobExitReason.LOST)
        elif update.state in ('TASK_KILLED', 'TASK_ERROR'):
            log.warning("Job %i is in unexpected state %s with message '%s' due to reason '%s'.",
                        jobID, update.state, update.message, update.reason)
            jobEnded(EXIT_STATUS_UNAVAILABLE_VALUE,
                     exitReason=(BatchJobExitReason.KILLED if update.state == 'TASK_KILLED' else BatchJobExitReason.ERROR))

        if 'limitation' in update:
            log.warning("Job limit info: %s" % update.limitation)

    def frameworkMessage(self, driver, executorId, agentId, message):
        """
        Invoked when an executor sends a message.

        The payload is decoded text holding a Python dict literal. Its mandatory 'address'
        entry registers the executor's node. An optional 'nodeInfo' dict becomes the node's
        NodeInfo, together with the cores and memory our running tasks on that executor
        requested. Any other key raises RuntimeError.
        """
        # Take it out of base 64 encoding from Protobuf
        text = decode_data(message).decode()

        log.debug('Got framework message from executor %s running on agent %s: %s',
                  executorId.value, agentId.value, text)
        fields = ast.literal_eval(text)
        assert isinstance(fields, dict)
        # Handle the mandatory fields of a message
        nodeAddress = fields.pop('address')
        executor = self._registerNode(nodeAddress, agentId.value)
        # Handle optional message fields
        for key, value in fields.items():
            if key != 'nodeInfo':
                raise RuntimeError("Unknown message field '%s'." % key)
            assert isinstance(value, dict)
            ownTasks = [taskData for taskData in self.runningJobMap.values()
                        if taskData.executorID == executorId.value]
            executor.nodeInfo = NodeInfo(requestedCores=sum(t.cores for t in ownTasks),
                                         requestedMemory=sum(t.memory for t in ownTasks),
                                         **value)
            self.executors[nodeAddress] = executor

    def _registerNode(self, nodeAddress, agentId, nodePort=5051):
        """
        Called when we get communication from an agent.

        Refreshes lastSeen on the ExecutorInfo stored under ``nodeAddress``. If none exists,
        or it belongs to a different agent ID, it is replaced with a fresh record. Also
        records ``nodeAddress`` as the address of ``agentId``. ``nodePort`` is accepted but
        unused.

        :return: the ExecutorInfo now stored for ``nodeAddress``.
        """
        executor = self.executors.get(nodeAddress)
        if executor is not None and executor.agentId == agentId:
            executor.lastSeen = time.time()
        else:
            executor = self.ExecutorInfo(nodeAddress=nodeAddress,
                                         agentId=agentId,
                                         nodeInfo=None,
                                         lastSeen=time.time())
            self.executors[nodeAddress] = executor

        # Record the IP under the agent id
        self.agentsByID[agentId] = nodeAddress

        return executor

    def getNodes(self,
                 preemptable: Optional[bool] = None,
                 timeout: Optional[int] = None) -> Dict[str, NodeInfo]:
        """
        Return the NodeInfo (possibly None) of every known node that matches:
         - preemptable status (None includes all)
         - timeout period (seen within the last # seconds, or None for all)

        An agent counts as preemptable whenever its ID is not in nonPreemptableNodes.
        """
        def matchesPreemptability(executor):
            return preemptable is None or preemptable == (executor.agentId not in self.nonPreemptableNodes)

        def seenRecently(executor):
            return timeout is None or time.time() - executor.lastSeen < timeout

        return {node_ip: executor.nodeInfo
                for node_ip, executor in self.executors.items()
                if matchesPreemptability(executor) and seenRecently(executor)}

    def reregistered(self, driver, masterInfo):
        """Invoked when the scheduler re-registers with a newly elected Mesos master; only logs it."""
        message = 'Registered with new master'
        log.debug(message)

    def _handleFailedExecutor(self, agentID, executorID=None):
        """
        Dump diagnostic logs for a failed executor into our own log.

        Should be called when we find out an executor has failed. We are never
        handed a container ID, so this lists every file the agent exposes via
        its ``/files/debug`` endpoint, then downloads and re-emits (one
        ``log.warning`` per line):

        - ``<path>/stderr`` for every exposed path that contains our framework
          ID, ``agentID`` and, if given, ``executorID``;
        - every other exposed path whose name ends in ``log``.

        All IDs are strings. If ``executorID`` is None, stderr of all our
        executors on the agent is dumped.

        This is a best-effort debugging aid: any failure (unknown agent,
        unreachable agent, malformed response) is logged and swallowed.

        :param agentID: ID of the agent the executor ran on.
        :param executorID: ID of the failed executor, or None for all of them.
        """

        log.warning("Handling failure of executor '%s' on agent '%s'.",
                    executorID, agentID)

        try:
            # Look up the IP. We should always know it unless we get answers
            # back without having accepted offers.
            agentAddress = self.agentsByID[agentID]

            # For now we assume the agent is always on the same port. We could
            # maybe sniff this from the URL that comes in the offer but it's
            # not guaranteed to be there.
            agentPort = 5051

            # We need the container ID to read the log, but we are never given
            # it, and the API only seems to report running containers. So we
            # list all the available files with /files/debug and look for the
            # ones that look right.
            filesQueryURL = "http://%s:%d/files/debug" % (agentAddress, agentPort)

            # The response is a JSON object from mounted name to real name.
            # Close the HTTP response instead of leaking its socket.
            with urlopen(filesQueryURL) as response:
                filesDict = json.loads(response.read())

            log.debug('Available files: %s', repr(filesDict.keys()))

            # stderr paths of matching container sandboxes.
            stderrFilenames = []
            # Anything else that looks like a log file (the agent's own logs).
            agentLogFilenames = []
            for filename in filesDict:
                if (self.frameworkId in filename and agentID in filename and
                        (executorID is None or executorID in filename)):
                    stderrFilenames.append("%s/stderr" % filename)
                elif filename.endswith("log"):
                    agentLogFilenames.append(filename)

            if not stderrFilenames:
                log.warning("Could not find any containers in '%s'.", filesDict)

            # According to http://mesos.apache.org/documentation/latest/sandbox/
            # we can use the web API to fetch these files.
            for stderrFilename in stderrFilenames:
                self._dumpAgentFile(agentAddress, agentPort, stderrFilename,
                                    "executor error log", "Executor")

            for agentLogFilename in agentLogFilenames:
                self._dumpAgentFile(agentAddress, agentPort, agentLogFilename,
                                    "agent log", "Agent")

        except Exception as e:
            log.warning("Could not retrieve logs due to: '%s'.", e)
            log.warning(traceback.format_exc())

    def _dumpAgentFile(self, agentAddress, agentPort, path, description, prefix):
        """
        Download one file from an agent's ``/files/download`` endpoint and log
        each of its lines as a warning of the form ``"<prefix>: <line>"``.

        Failures are logged and swallowed so that one unreadable file does not
        stop the caller from dumping the remaining files.

        :param agentAddress: host of the Mesos agent.
        :param agentPort: port of the agent's HTTP API.
        :param path: file path as listed by the agent's ``/files/debug``.
        :param description: human-readable name used in progress/failure logs.
        :param prefix: tag prepended to every re-emitted line.
        """
        try:
            url = "http://%s:%d/files/download?path=%s" % (
                agentAddress, agentPort, quote_plus(path))

            log.warning("Attempting to retrieve %s: %s", description, url)

            with urlopen(url) as response:
                for line in response:
                    # The response yields bytes; decode so the log shows the
                    # text rather than a b'...' repr.
                    log.warning("%s: %s", prefix,
                                line.decode('utf-8', errors='replace').rstrip())

        except Exception as e:
            log.warning("Could not retrieve %s due to: '%s'.", description, e)
            log.warning(traceback.format_exc())

    def executorLost(self, driver, executorId, agentId, status):
        """
        Mesos callback: an executor exited or terminated abnormally.

        Logs the loss together with the reported status, then hands the
        agent ID and executor ID to _handleFailedExecutor so the available
        executor/agent logs are dumped for debugging.
        """
        lostExecutor = executorId.get('value', None)
        log.warning("Executor '%s' reported lost with status '%s'.", lostExecutor, status)
        agentValue = agentId.value
        self._handleFailedExecutor(agentValue, lostExecutor)

    @classmethod
    def get_default_mesos_endpoint(cls) -> str:
        """
        Build the default Mesos master address: the public IP reported by
        get_public_ip(), followed by port 5050.
        """
        host = get_public_ip()
        return '{}:5050'.format(host)

    @classmethod
    def add_options(cls, parser: Union[ArgumentParser, _ArgumentGroup]) -> None:
        """
        Register the Mesos endpoint option on ``parser``.

        Adds ``--mesosEndpoint`` (alias ``--mesosMaster``), stored under
        ``mesos_endpoint`` with get_default_mesos_endpoint() as its default.
        """
        defaultEndpoint = cls.get_default_mesos_endpoint()
        parser.add_argument(
            "--mesosEndpoint", "--mesosMaster",
            dest="mesos_endpoint",
            default=defaultEndpoint,
            help="The host and port of the Mesos master separated by colon.  (default: %(default)s)",
        )

    @classmethod
    def setOptions(cls, setOption):
        """
        Report the ``mesos_endpoint`` option through ``setOption``.

        get_default_mesos_endpoint() is passed as the fourth positional
        argument (presumably the default value — confirm against setOption),
        and ``mesosMasterAddress`` is listed under ``old_names``.
        """
        fallbackEndpoint = cls.get_default_mesos_endpoint()
        setOption("mesos_endpoint", None, None, fallbackEndpoint,
                  old_names=["mesosMasterAddress"])
コード例 #16
0
class ProcScheduler(Scheduler):
    """
    Mesos scheduler that runs submitted "procs" as Mesos tasks.

    Bookkeeping, mutated under ``self._lock``:

    - ``procs_pending``: proc id -> proc submitted but not yet launched;
    - ``procs_launched``: proc id -> proc sent to Mesos in a task;
    - ``slave_to_proc``: slave id string -> set of ids of procs that reported
      TASK_RUNNING on that slave.

    Procs are expected to expose ``id``, ``cpus``, ``mem``, ``params`` and the
    callbacks ``_started()`` and ``_finished(success, message, data)``.
    """

    # Mesos task states after which a task never runs again. They are resolved
    # by name at runtime because newer states (e.g. TASK_DROPPED) may be absent
    # from older mesos_pb2 builds. A ``>= TASK_FINISHED`` comparison is wrong:
    # enum values are not ordered by lifecycle (TASK_STAGING is 6, TASK_KILLING
    # is 8, both non-terminal).
    _TERMINAL_STATE_NAMES = (
        'TASK_FINISHED', 'TASK_FAILED', 'TASK_KILLED', 'TASK_LOST',
        'TASK_ERROR', 'TASK_DROPPED', 'TASK_GONE', 'TASK_GONE_BY_OPERATOR',
    )

    def __init__(self):
        self.framework_id = None
        self.framework = self._init_framework()
        self.executor = None
        # Consult the environment only when the config has no master, so a
        # configured master works even if MESOS_MASTER is unset.
        master = CONFIG.get('master', None)
        if master is None:
            master = os.environ['MESOS_MASTER']
        self.master = str(master)
        self.driver = MesosSchedulerDriver(self, self.framework, self.master)
        self.procs_pending = {}
        self.procs_launched = {}
        self.slave_to_proc = {}
        self._lock = RLock()

    def _init_framework(self):
        """Build the FrameworkInfo: current user, repr(self) as name, local hostname."""
        framework = mesos_pb2.FrameworkInfo()
        framework.user = getpass.getuser()
        framework.name = repr(self)
        framework.hostname = socket.gethostname()
        return framework

    def _init_executor(self):
        """
        Build the ExecutorInfo: runs ``<python> -m <package>.executor`` with
        MIN_MEMORY mem and MIN_CPUS cpus, forwarding PYTHONPATH if set.
        Requires ``self.framework_id`` (set in ``registered``).
        """
        executor = mesos_pb2.ExecutorInfo()
        executor.executor_id.value = 'default'
        executor.command.value = '%s -m %s.executor' % (
            sys.executable, __package__)

        mem = executor.resources.add()
        mem.name = 'mem'
        mem.type = mesos_pb2.Value.SCALAR
        mem.scalar.value = MIN_MEMORY

        cpus = executor.resources.add()
        cpus.name = 'cpus'
        cpus.type = mesos_pb2.Value.SCALAR
        cpus.scalar.value = MIN_CPUS

        if 'PYTHONPATH' in os.environ:
            var = executor.command.environment.variables.add()
            var.name = 'PYTHONPATH'
            var.value = os.environ['PYTHONPATH']

        executor.framework_id.value = str(self.framework_id.value)
        return executor

    def _init_task(self, proc, offer):
        """
        Build the TaskInfo for ``proc`` on ``offer``'s slave: task id is the
        proc id, data is the pickled ``proc.params``, resources are the
        proc's cpus/mem.
        """
        task = mesos_pb2.TaskInfo()
        task.task_id.value = str(proc.id)
        task.slave_id.value = offer.slave_id.value
        task.name = repr(proc)
        task.executor.MergeFrom(self.executor)
        task.data = pickle.dumps(proc.params)

        cpus = task.resources.add()
        cpus.name = 'cpus'
        cpus.type = mesos_pb2.Value.SCALAR
        cpus.scalar.value = proc.cpus

        mem = task.resources.add()
        mem.name = 'mem'
        mem.type = mesos_pb2.Value.SCALAR
        mem.scalar.value = proc.mem

        return task

    def _filters(self, seconds):
        """Return a Filters message refusing the offer for ``seconds``."""
        f = mesos_pb2.Filters()
        f.refuse_seconds = seconds
        return f

    def __repr__(self):
        return "%s[%s]: %s" % (
            self.__class__.__name__,
            os.getpid(), ' '.join(sys.argv))

    @classmethod
    def _is_terminal(cls, state):
        """Return True if ``state`` is a terminal Mesos task state."""
        for name in cls._TERMINAL_STATE_NAMES:
            value = getattr(mesos_pb2, name, None)
            if value is not None and state == value:
                return True
        return False

    def registered(self, driver, framework_id, master_info):
        """Remember the framework id and build the executor that depends on it."""
        with self._lock:
            logger.info('Framework registered with id=%s, master=%s' % (
                framework_id, master_info))
            self.framework_id = framework_id
            self.executor = self._init_executor()

    def resourceOffers(self, driver, offers):
        """
        Pack pending procs into the received offers.

        Offers are visited in random order. While nothing is pending, an offer
        is answered with an empty launch and a FOREVER filter (submit() revives
        offers again). Otherwise pending procs are assigned greedily to the
        offer's remaining cpus/mem and moved to ``procs_launched``. The offer is
        then filtered for a random 5-10 seconds.
        """
        def get_resources(offer):
            # Scalar cpus/mem of the offer; 0.0 when a resource is absent.
            cpus, mem = 0.0, 0.0
            for r in offer.resources:
                if r.name == 'cpus':
                    cpus = float(r.scalar.value)
                elif r.name == 'mem':
                    mem = float(r.scalar.value)
            return cpus, mem

        with self._lock:
            random.shuffle(offers)
            for offer in offers:
                if not self.procs_pending:
                    logger.debug('Reject offers forever for no pending procs, '
                                 'offers=%s' % (offers, ))
                    driver.launchTasks(offer.id, [], self._filters(FOREVER))
                    continue

                cpus, mem = get_resources(offer)
                tasks = []
                # Snapshot: launched procs are deleted from procs_pending
                # inside this loop (a live dict view would raise on Python 3).
                for proc in list(self.procs_pending.values()):
                    if cpus >= proc.cpus and mem >= proc.mem:
                        tasks.append(self._init_task(proc, offer))
                        del self.procs_pending[proc.id]
                        self.procs_launched[proc.id] = proc
                        cpus -= proc.cpus
                        mem -= proc.mem

                seconds = 5 + random.random() * 5
                driver.launchTasks(offer.id, tasks, self._filters(seconds))
                if tasks:
                    logger.info('Accept offer for procs, offer=%s, '
                                'procs=%s, filter_time=%s' % (
                                    offer,
                                    [int(t.task_id.value) for t in tasks],
                                    seconds))
                else:
                    logger.info('Retry offer for procs later, offer=%s, '
                                'filter_time=%s' % (
                                    offer, seconds))

    def _call_finished(self, proc_id, success, message, data, slave_id=None):
        """
        Forget ``proc_id`` and report its outcome via ``proc._finished``.

        The proc may be launched or still pending (``error`` finishes both).
        If it is in neither map, e.g. because it was cancelled before its
        terminal status update arrived, only the slave bookkeeping is cleaned.

        :param slave_id: slave id string the proc ran on, if known; when None
            every slave's set is searched.
        """
        with self._lock:
            proc = self.procs_launched.pop(proc_id, None)
            if proc is None:
                proc = self.procs_pending.pop(proc_id, None)

            if slave_id is not None:
                if slave_id in self.slave_to_proc:
                    # discard: the proc may never have reported TASK_RUNNING.
                    self.slave_to_proc[slave_id].discard(proc_id)
            else:
                for procs in self.slave_to_proc.values():
                    procs.discard(proc_id)

            if proc is not None:
                proc._finished(success, message, data)

    def statusUpdate(self, driver, update):
        """
        Track task state: TASK_RUNNING records the slave and notifies the
        proc; a terminal state finishes the proc and revives offers.
        """
        with self._lock:
            proc_id = int(update.task_id.value)
            logger.info('Status update for proc, id=%s, state=%s' % (
                proc_id, update.state))
            if update.state == mesos_pb2.TASK_RUNNING:
                slave_id = update.slave_id.value
                self.slave_to_proc.setdefault(slave_id, set()).add(proc_id)

                # The proc may already have been cancelled.
                proc = self.procs_launched.get(proc_id)
                if proc is not None:
                    proc._started()

            elif self._is_terminal(update.state):
                slave_id = update.slave_id.value
                success = (update.state == mesos_pb2.TASK_FINISHED)
                message = update.message
                # NOTE(review): unpickles executor-provided bytes; only safe
                # if executors are trusted.
                data = update.data and pickle.loads(update.data)
                self._call_finished(proc_id, success, message, data, slave_id)
                driver.reviveOffers()

    def offerRescinded(self, driver, offer_id):
        """Ask for offers again if procs are still waiting to be launched."""
        with self._lock:
            if self.procs_pending:
                logger.info('Revive offers for pending procs')
                driver.reviveOffers()

    def slaveLost(self, driver, slave_id):
        """Fail every proc that reported TASK_RUNNING on the lost slave."""
        # Mesos passes a SlaveID message, while slave_to_proc is keyed by its
        # ``value`` string; accept a plain string as well.
        key = getattr(slave_id, 'value', slave_id)
        with self._lock:
            for proc_id in list(self.slave_to_proc.pop(key, ())):
                self._call_finished(
                    proc_id, False, 'Slave lost', None, key)

    def error(self, driver, message):
        """Fail every pending and launched proc, then stop a still-live driver."""
        with self._lock:
            # Snapshots: _call_finished removes procs from these dicts.
            for proc in list(self.procs_pending.values()):
                self._call_finished(proc.id, False, 'Stopped', None)

            for proc in list(self.procs_launched.values()):
                self._call_finished(proc.id, False, 'Stopped', None)

        # Per the Mesos Scheduler API the driver is aborted before ``error`` is
        # invoked; stop() asserts it is not aborted, so only stop a live driver.
        if not self.driver.aborted:
            self.stop()

    def start(self):
        """Start the Mesos driver."""
        self.driver.start()

    def stop(self):
        """Stop the Mesos driver; it must not have been aborted."""
        assert not self.driver.aborted
        self.driver.stop()

    def submit(self, proc):
        """
        Queue ``proc`` for launch, reviving offers when it is the only
        pending proc.

        :raises RuntimeError: if the driver was aborted.
        :raises ValueError: if a proc with the same id is already pending.
        """
        if self.driver.aborted:
            raise RuntimeError('driver already aborted')

        with self._lock:
            if proc.id not in self.procs_pending:
                logger.info('Try submit proc, id=%s', proc.id)
                self.procs_pending[proc.id] = proc
                if len(self.procs_pending) == 1:
                    logger.info('Revive offers for pending procs')
                    self.driver.reviveOffers()
            else:
                raise ValueError('Proc with same id already submitted')

    def cancel(self, proc):
        """
        Drop ``proc`` from the pending queue, or kill its task if launched,
        and remove it from the slave bookkeeping.

        :raises RuntimeError: if the driver was aborted.
        """
        if self.driver.aborted:
            raise RuntimeError('driver already aborted')

        with self._lock:
            if proc.id in self.procs_pending:
                del self.procs_pending[proc.id]
            elif proc.id in self.procs_launched:
                del self.procs_launched[proc.id]
                self.driver.killTask(mesos_pb2.TaskID(value=str(proc.id)))

            # Snapshot: emptied slave entries are deleted while looping.
            for slave_id, procs in list(self.slave_to_proc.items()):
                # Sets have no pop(key); discard also tolerates absent ids.
                procs.discard(proc.id)
                if not procs:
                    del self.slave_to_proc[slave_id]

    def send_data(self, pid, type, data):
        """
        Send the pickled ``(pid, type, data)`` as a framework message to the
        executor on the slave where proc ``pid`` reported TASK_RUNNING.

        :raises RuntimeError: if the driver was aborted or no slave is known
            for ``pid``.
        """
        if self.driver.aborted:
            raise RuntimeError('driver already aborted')

        msg = pickle.dumps((pid, type, data))
        # slave_to_proc is mutated from driver callbacks under this lock.
        with self._lock:
            for slave_id, procs in self.slave_to_proc.items():
                if pid in procs:
                    self.driver.sendFrameworkMessage(
                        self.executor.executor_id,
                        mesos_pb2.SlaveID(value=slave_id),
                        msg)
                    return

        raise RuntimeError('Cannot find slave for pid %s' % (pid,))
コード例 #17
0
ファイル: scheduler.py プロジェクト: weiqiangzheng/dpark
    signal.signal(signal.SIGHUP, handler)
    signal.signal(signal.SIGABRT, handler)
    signal.signal(signal.SIGQUIT, handler)

    try:
        from rfoo.utils import rconsole
        rconsole.spawn_server(locals(), 0)
    except ImportError:
        pass

    try:
        driver.start()
        sched.run(driver)
    except KeyboardInterrupt:
        logger.warning('stopped by KeyboardInterrupt')
        sched.stop(EXIT_KEYBORAD)
    except Exception as e:
        import traceback
        logger.warning('catch unexpected Exception, exit now. %s',
                       traceback.format_exc())
        sched.stop(EXIT_EXCEPTION)
    finally:
        # sched.lock may be in WRONG status.
        # if any thread of sched may use lock or call driver, join it first
        driver.stop(False)
        driver.join()
        # mesos resourses are released, and no racer for lock any more
        sched.cleanup()
        ctx.term()
        sys.exit(sched.status)
コード例 #18
0
ファイル: scheduler.py プロジェクト: pandasasa/tfmesos
class TFMesosScheduler(Scheduler):

    def __init__(self, task_spec, master=None, name=None, quiet=False,
                 volumes=None):
        """
        Create one Task per (job, task_index) described by ``task_spec``.

        :param task_spec: iterable of job specs, each exposing ``name``,
            ``num``, ``cpus`` and ``mem``.
        :param master: Mesos master address; defaults to ``$MESOS_MASTER``.
        :param name: framework name; defaults to
            ``"[tensorflow] <abs script path> <args>"``.
        :param quiet: when False, the module logger is configured with
            setup_logger.
        :param volumes: passed unchanged to every Task; defaults to a new empty
            dict for each scheduler.
        """
        if volumes is None:
            # A fresh dict per scheduler; a ``{}`` default would be a single
            # object shared by every call.
            volumes = {}
        self.started = False
        self.master = master or os.environ['MESOS_MASTER']
        self.name = name or '[tensorflow] %s %s' % (
            os.path.abspath(sys.argv[0]), ' '.join(sys.argv[1:]))
        self.task_spec = task_spec
        self.tasks = []
        for job in task_spec:
            for task_index in range(job.num):
                # The position in self.tasks doubles as the Mesos task id
                # (statusUpdate and start index self.tasks with it).
                mesos_task_id = len(self.tasks)
                self.tasks.append(
                    Task(
                        mesos_task_id,
                        job.name,
                        task_index,
                        cpus=job.cpus,
                        mem=job.mem,
                        volumes=volumes,
                    )
                )
        if not quiet:
            setup_logger(logger)

    def resourceOffers(self, driver, offers):
        '''
        Place not-yet-offered tasks into the received offers.

        When every task has already been offered, the offer is declined with a
        FOREVER refuse filter. Otherwise tasks are packed greedily, in
        ``self.tasks`` order, into the offer's cpus/mem, marked offered and
        launched. An offer that fits nothing gets an empty launch.
        '''
        for offer in offers:
            pending = [task for task in self.tasks if not task.offered]
            if not pending:
                driver.declineOffer(offer.id,
                                    mesos_pb2.Filters(refuse_seconds=FOREVER))
                continue

            # Last matching resource wins; absent resources count as 0.0.
            free = {'cpus': 0.0, 'mem': 0.0}
            for resource in offer.resources:
                if resource.name in free:
                    free[resource.name] = resource.scalar.value

            launch = []
            for task in pending:
                fits = task.cpus <= free['cpus'] and task.mem <= free['mem']
                if fits:
                    free['cpus'] -= task.cpus
                    free['mem'] -= task.mem
                    task.offered = True
                    launch.append(task.to_task_info(offer, self.addr))

            driver.launchTasks(offer.id, launch, mesos_pb2.Filters())

    def _start_tf_cluster(self):
        """
        Send every connected task its cluster definition and wait for "ok".

        ``cluster_def`` maps each job name to the addresses of its tasks, in
        ``self.tasks`` order. Each task's connection is closed after its
        handshake, and ``self.started`` is set once all tasks are activated.

        :return: dict from ``"/job:<name>/task:<index>"`` to
            ``"grpc://<addr>"``.
        :raises RuntimeError: if a task replies with anything but ``"ok"``.
        """
        cluster_def = {}

        targets = {}
        for task in self.tasks:
            target_name = '/job:%s/task:%s' % (task.job_name, task.task_index)
            grpc_addr = 'grpc://%s' % task.addr
            targets[target_name] = grpc_addr
            cluster_def.setdefault(task.job_name, []).append(task.addr)

        for task in self.tasks:
            response = {
                "job_name": task.job_name,
                "task_index": task.task_index,
                "cpus": task.cpus,
                "mem": task.mem,
                "cluster_def": cluster_def,
            }
            send(task.connection, response)
            # Not an assert: under ``python -O`` the whole assert statement,
            # including this recv(), would be skipped.
            reply = recv(task.connection)
            if reply != "ok":
                raise RuntimeError(
                    "Task /job:%s/task:%s rejected cluster definition: %r" % (
                        task.job_name, task.task_index, reply))
            logger.info(
                "Device /job:%s/task:%s activated @ grpc://%s " % (
                    task.job_name,
                    task.task_index,
                    task.addr
                )
            )
            task.connection.close()

        # statusUpdate, slaveLost and executorLost treat failures as fatal
        # only once this flag is set; nothing else in the class sets it.
        self.started = True
        return targets

    def start(self):
        """
        Register the framework and bring up the TensorFlow cluster.

        Listens on an ephemeral port and stores ``hostname:port`` in
        ``self.addr`` (resourceOffers passes it to task infos), then starts
        the Mesos driver. Connections are accepted until every task has sent
        ``(mesos_task_id, addr)``. Finally hands out the cluster definition
        via _start_tf_cluster and returns its target map.

        On any exception the scheduler is stopped and the exception is
        re-raised. The listening socket is always closed.
        """
        def has_data(sock):
            ready, _, _ = select.select([sock], [], [], 0.1)
            return bool(ready)

        listener = socket.socket()
        try:
            listener.bind(('', 0))
            port = listener.getsockname()[1]
            self.addr = '%s:%s' % (socket.gethostname(), port)
            listener.listen(10)

            framework = mesos_pb2.FrameworkInfo()
            framework.user = getpass.getuser()
            framework.name = self.name
            framework.hostname = socket.gethostname()
            self.driver = MesosSchedulerDriver(self, framework, self.master)
            self.driver.start()

            while not all(task.initalized for task in self.tasks):
                if not has_data(listener):
                    continue
                conn, _ = listener.accept()
                if not has_data(conn):
                    # Peer sent nothing within the poll window; drop it.
                    conn.close()
                    continue
                mesos_task_id, addr = recv(conn)
                assert isinstance(mesos_task_id, int)
                task = self.tasks[mesos_task_id]
                task.addr = addr
                task.connection = conn
                task.initalized = True

            return self._start_tf_cluster()
        except Exception:
            self.stop()
            raise
        finally:
            listener.close()

    def registered(self, driver, framework_id, master_info):
        """Log the master web UI link of the newly registered framework."""
        url = 'http://%s:%s/#/frameworks/%s' % (
            master_info.hostname, master_info.port, framework_id.value)
        logger.info("Tensorflow cluster registered. ( %s )" % url)

    def statusUpdate(self, driver, update):
        """
        Mesos callback for task status changes.

        Any state other than TASK_RUNNING is treated as a task failure. Once
        ``self.started`` is set, it is reported through _raise. Before that,
        the task's connection (if any) is closed and offers are revived.
        """
        mesos_task_id = int(update.task_id.value)
        if update.state != mesos_pb2.TASK_RUNNING:
            task = self.tasks[mesos_task_id]
            if self.started:
                logger.error("Task failed: %s" % task)
                # Report the task id, not the builtin ``id`` function.
                _raise(RuntimeError('Task %s failed!' % mesos_task_id))
            else:
                logger.warning("Task failed: %s" % task)
                # A task that failed before connecting back has no usable
                # connection (None or attribute absent).
                connection = getattr(task, 'connection', None)
                if connection is not None:
                    connection.close()
                driver.reviveOffers()

    def slaveLost(self, driver, slaveId):
        """
        Mesos callback for a lost agent: fatal (via _raise) once
        ``self.started`` is set, ignored before that.
        """
        if self.started:
            logger.error("Slave %s lost:" % slaveId.value)
            # Use the ID string, not the protobuf's multi-line text form.
            _raise(RuntimeError('Slave %s lost' % slaveId.value))

    def executorLost(self, driver, executorId, slaveId, status):
        """
        Mesos callback for a lost executor: fatal (via _raise) once
        ``self.started`` is set, ignored before that.
        """
        if self.started:
            logger.error("Executor %s lost:" % executorId.value)
            # Use the ID strings, not the protobufs' multi-line text form.
            _raise(RuntimeError('Executor %s@%s lost' % (
                executorId.value, slaveId.value)))

    def error(self, driver, message):
        """
        Mesos callback for an unrecoverable scheduler/driver error: log it and
        propagate it through _raise as a RuntimeError.
        """
        logger.error("Mesos error: %s" % message)
        failure = RuntimeError('Error ' + message)
        _raise(failure)

    def stop(self):
        """
        Close task connections and stop the Mesos driver.

        Safe to call repeatedly and from start()'s error path: the task list
        and the driver are deleted after use. Tasks without a connection are
        skipped rather than raising, which would otherwise replace the
        exception start() is about to re-raise.
        """
        logger.debug("exit")

        if hasattr(self, "tasks"):
            for task in self.tasks:
                # None or missing when the task never connected back.
                connection = getattr(task, "connection", None)
                if connection is not None:
                    connection.close()
            del self.tasks

        if hasattr(self, "driver"):
            self.driver.stop()
            del self.driver