示例#1
0
 def test_task_host_manager(self):
     """Check TaskHostManager's host choice as failures/successes accumulate.

     Three hosts are registered with ``purge_elapsed=1`` and the test sleeps
     one second between phases (presumably so recorded failures age out --
     confirm against TaskHostManager).
     """
     mgr = TaskHostManager()
     for host in ('fake1', 'fake2', 'fake3'):
         mgr.register_host(host, purge_elapsed=1)
     offers = {'fake%d' % idx: (idx, None) for idx in (1, 2, 3)}

     def pick(excluded):
         # first element of offer_choice() for task 1; `excluded` is
         # presumably a list of hosts to avoid -- verify against the manager
         return mgr.offer_choice(1, offers, excluded)[0]

     mgr.task_failed(1, 'fake2', OtherFailure('Mock failed'))
     assert pick(['fake3']) == 1
     time.sleep(1)
     mgr.task_failed(1, 'fake1', OtherFailure('Mock failed'))
     assert pick([]) == 3
     assert pick(['fake3']) is None
     mgr.task_succeed(2, 'fake2', TaskEndReason.success)
     assert pick(['fake3']) is None
     time.sleep(1)
     assert pick(['fake3']) == 2
示例#2
0
 def test_task_host_manager(self):
     """Walk TaskHostManager through failures/successes on three hosts.

     Hosts use ``purge_elapsed=1``; the one-second sleeps separate the
     phases of the scenario.
     """
     host_mgr = TaskHostManager()
     names = ['fake%d' % n for n in range(1, 4)]
     for name in names:
         host_mgr.register_host(name, purge_elapsed=1)
     host_offers = dict((name, (n, None)) for n, name in enumerate(names, 1))

     def chosen(hosts):
         # index chosen by offer_choice() for task 1 (None when no host fits)
         return host_mgr.offer_choice(1, host_offers, hosts)[0]

     host_mgr.task_failed(1, 'fake2', OtherFailure('Mock failed'))
     assert chosen(['fake3']) == 1
     time.sleep(1)
     host_mgr.task_failed(1, 'fake1', OtherFailure('Mock failed'))
     assert chosen([]) == 3
     assert chosen(['fake3']) is None
     host_mgr.task_succeed(2, 'fake2', TaskEndReason.success)
     assert chosen(['fake3']) is None
     time.sleep(1)
     assert chosen(['fake3']) == 2
示例#3
0
文件: schedule.py 项目: posens/dpark
class MesosScheduler(DAGScheduler):
    def __init__(self, master, options, webui_url=None):
        """Create a scheduler for ``master`` configured from ``options``.

        Only local state is prepared here; the Mesos driver itself is
        created later by ``start_driver``.
        """
        DAGScheduler.__init__(self)
        self.master = master
        self.options = options
        self.webui_url = webui_url

        # resource requests and scheduling knobs copied from options
        self.cpus = options.cpus
        self.mem = options.mem
        self.task_per_node = options.parallel or 8
        self.group = options.group
        self.logLevel = options.logLevel
        self.role = options.role
        self.color = options.color

        # driver / registration lifecycle state
        self.started = False
        self.isRegistered = False
        self.last_finish_time = 0
        self.executor = None
        self.driver = None

        self.out_logger = LogReceiver(sys.stdout)
        self.err_logger = LogReceiver(sys.stderr)
        self.lock = threading.RLock()
        self.task_host_manager = TaskHostManager()
        self.init_tasksets()

    def init_tasksets(self):
        """Reset the taskset registry and the task/agent bookkeeping maps."""
        # taskset id -> TaskSet, mesos task id -> agent id,
        # agent id -> count of our tasks tracked on that agent
        (self.active_tasksets,
         self.ttid_to_agent_id,
         self.agent_id_to_ttids) = {}, {}, {}

    def clear(self):
        """Reset inherited DAG state, then the Mesos-side taskset state."""
        base_clear = DAGScheduler.clear
        base_clear(self)
        # the base class does not know about tasksets/agents; reset ours too
        self.init_tasksets()

    def processHeartBeat(self):
        """Do nothing; kept only for compatibility with pymesos."""
        return None

    def start(self):
        """Start the stdout and stderr log receivers.

        The Mesos driver is not started here; ``submitTasks`` starts it on
        first use.
        """
        for receiver in (self.out_logger, self.err_logger):
            receiver.start()

    def start_driver(self):
        """Create and start the Mesos driver plus a background idle watcher.

        The watcher thread (started via ``spawn``) stops the scheduler once
        there have been no active tasksets for more than ``MAX_IDLE_TIME``
        seconds, and revives offers whenever a taskset reports a timeout.

        Raises:
            Exception: if the current user is ``root``.
        """
        script_path = os.path.abspath(sys.argv[0])
        arguments = ' '.join(sys.argv[1:])
        name = '[dpark] ' + script_path + ' ' + arguments
        if len(name) > 256:
            name = name[:256] + '...'

        framework = Dict()
        framework.user = getuser()
        if framework.user == 'root':
            raise Exception('dpark is not allowed to run as \'root\'')
        framework.name = name
        if self.role:
            framework.role = self.role
        framework.hostname = socket.gethostname()
        if self.webui_url:
            framework.webui_url = self.webui_url

        self.driver = MesosSchedulerDriver(
            self, framework, self.master, use_addict=True)
        self.driver.start()
        logger.debug('Mesos Scheudler driver started')

        self.started = True
        self.last_finish_time = time.time()

        def check():
            # runs until stop() clears self.started (or it stops us itself)
            while self.started:
                with self.lock:
                    idle = time.time() - self.last_finish_time
                    if not self.active_tasksets and idle > MAX_IDLE_TIME:
                        logger.info(
                            'stop mesos scheduler after %d seconds idle',
                            idle)
                        self.stop()
                        break
                    for taskset in self.active_tasksets.values():
                        if taskset.check_task_timeout():
                            self.requestMoreResources()
                time.sleep(1)

        spawn(check)

    @safe
    def registered(self, driver, frameworkId, masterInfo):
        """Mesos callback: remember the framework id and build executor info.

        Also sets ``isRegistered``, which ``submitTasks`` waits on, and
        attaches the loghub directory for this framework.
        """
        fid = frameworkId.value
        self.isRegistered = True
        self.frameworkId = fid
        logger.debug('connect to master %s:%s, registered as %s',
                     masterInfo.hostname, masterInfo.port, fid)
        self.executor = self.getExecutorInfo(str(fid))
        from dpark.utils.log import add_loghub
        self.loghub_dir = add_loghub(self.frameworkId)

    @safe
    def reregistered(self, driver, masterInfo):
        """Mesos callback: the driver re-registered; log the master address."""
        host, port = masterInfo.hostname, masterInfo.port
        logger.warning('re-connect to mesos master %s:%s', host, port)

    @safe
    def disconnected(self, driver):
        """Mesos callback: the driver lost its master; only logged here."""
        message = 'framework is disconnected'
        logger.debug(message)

    def _get_container_image(self):
        """Return ``options.image``; a falsy value means no container."""
        image = self.options.image
        return image

    @safe
    def getExecutorInfo(self, framework_id):
        """Build the Mesos ExecutorInfo used to launch dpark executors.

        The executor command runs ``executor.py`` from this module's directory
        with the current interpreter.  When a container image is configured,
        the executor runs in a DOCKER container with /etc/passwd and
        /etc/group mounted read-only, the MooseFS mount points and dpark work
        dirs mounted read-write, plus any extra ``options.volumes`` entries
        (``host:container[:mode]`` or just ``container``, comma separated).

        Args:
            framework_id: the framework id string assigned by Mesos.

        Returns:
            An addict ``Dict`` describing the executor.

        Raises:
            Exception: if an ``options.volumes`` entry has more than three
                ``:``-separated fields.
            AssertionError: if a volume mode is not RO/RW, or if the encoded
                executor data reaches 50 MiB.
        """
        info = Dict()
        info.framework_id.value = framework_id
        info.command.value = '%s %s' % (
            sys.executable,
            os.path.abspath(
                os.path.join(os.path.dirname(__file__), 'executor.py')))
        info.executor_id.value = env.get('DPARK_ID', 'default')

        # export this process's uid/gid into the executor's environment
        info.command.environment.variables = variables = []
        for var_name, var_value in (('UID', os.getuid()),
                                    ('GID', os.getgid())):
            v = Dict()
            variables.append(v)
            v.name = var_name
            v.value = str(var_value)

        container_image = self._get_container_image()
        if container_image:
            info.container.type = 'DOCKER'
            info.container.docker.image = container_image
            info.container.docker.parameters = parameters = []
            p = Dict()
            parameters.append(p)
            p.key = 'memory-swap'
            p.value = '-1'

            info.container.volumes = volumes = []
            fixed_mounts = (
                [(path, 'RO') for path in ('/etc/passwd', '/etc/group')] +
                [(path, 'RW') for path in conf.MOOSEFS_MOUNT_POINTS] +
                [(path, 'RW') for path in conf.DPARK_WORK_DIR.split(',')])
            for path, mode in fixed_mounts:
                v = Dict()
                volumes.append(v)
                v.host_path = v.container_path = path
                v.mode = mode

            if self.options.volumes:
                for volume in self.options.volumes.split(','):
                    fields = volume.split(':')
                    if len(fields) == 3:
                        host_path, container_path, mode = fields
                        mode = mode.upper()
                        assert mode in ('RO', 'RW'), \
                            'bad mode in volume %s' % (volume,)
                    elif len(fields) == 2:
                        host_path, container_path = fields
                        mode = 'RW'
                    elif len(fields) == 1:
                        container_path, = fields
                        host_path = ''
                        mode = 'RW'
                    else:
                        # format here: Exception() does not %-format its
                        # extra positional arguments
                        raise Exception('cannot parse volume %s' % (volume,))
                    v = Dict()
                    volumes.append(v)
                    v.container_path = container_path
                    v.mode = mode
                    # without a host path the key is omitted entirely
                    if host_path:
                        v.host_path = host_path

        info.resources = resources = []
        for res_name, amount in (('mem', EXECUTOR_MEMORY),
                                 ('cpus', EXECUTOR_CPUS)):
            r = Dict()
            resources.append(r)
            r.name = res_name
            r.type = 'SCALAR'
            r.scalar.value = amount

        script = os.path.realpath(sys.argv[0])
        info.name = script

        # driver-side context shipped to the executor; marshal only
        # serializes core builtin types, so keep every element to those
        info.data = encode_data(
            marshal.dumps((script, os.getcwd(), sys.path, dict(os.environ),
                           self.task_per_node, self.out_logger.addr,
                           self.err_logger.addr, self.logLevel, self.color,
                           env.environ)))
        assert len(info.data) < (50 << 20), \
            'Info data too large: %s' % (len(info.data),)
        return info

    @safe
    def submitTasks(self, tasks):
        """Register a batch of tasks from one RDD as a new TaskSet.

        Starts the Mesos driver on first use and waits until the framework
        is registered.  If the driver was already running, asks Mesos to
        revive offers so the new tasks get scheduled.

        Args:
            tasks: list of tasks that all share the same ``rdd``; an empty
                list is a no-op.
        """
        if not tasks:
            return

        first = tasks[0]
        rdd = first.rdd
        assert all(t.rdd is rdd for t in tasks)

        taskset = TaskSet(self, tasks, rdd.cpus or self.cpus,
                          rdd.mem or self.mem, rdd.gpus,
                          self.task_host_manager)
        self.active_tasksets[taskset.id] = taskset

        stage_scope = ''
        try:
            from dpark.web.ui.views.rddopgraph import StageInfo
            stage_scope = StageInfo.idToRDDNode[rdd.id].scope.call_site
        except Exception:
            # best effort: the call site is only used for the log line below
            pass

        stage = self.idToStage[first.stage_id]
        stage.num_try += 1
        logger.info(
            'Got taskset %s with %d tasks for stage: %d '
            'at scope[%s] and rdd:%s', taskset.id, len(tasks),
            first.stage_id, stage_scope, first.rdd)

        need_revive = self.started
        if not self.started:
            self.start_driver()

        # NOTE(review): assumes self.lock is held exactly once on entry
        # (presumably by @safe); the lock is dropped while polling so the
        # `registered` callback can run.
        while not self.isRegistered:
            self.lock.release()
            try:
                time.sleep(0.01)
            finally:
                # always re-acquire so the caller's eventual release stays
                # balanced, even if the sleep is interrupted
                self.lock.acquire()

        if need_revive:
            self.requestMoreResources()

    def requestMoreResources(self):
        """Ask the Mesos driver to revive (re-send) resource offers."""
        driver = self.driver
        logger.debug('reviveOffers')
        driver.reviveOffers()

    @safe
    def resourceOffers(self, driver, offers):
        """Mesos callback: pack tasks from active tasksets into ``offers``.

        When there are no active tasksets, offers are suppressed and every
        offer is declined with a 5 minute filter.  Otherwise the following
        offers are declined:

        * offers on hosts banned by ``conf.ban``;
        * offers outside the allowed groups;
        * offers on hosts that the TaskHostManager reports as unhealthy.

        The remaining offers are filled with tasks and launched, or declined
        if nothing was placed on them.
        """
        if not self.active_tasksets:
            driver.suppressOffers()
            rf = Dict()
            rf.refuse_seconds = 60 * 5
            for o in offers:
                driver.declineOffer(o.id, rf)
            return

        start = time.time()
        usable = []
        for o in offers:
            try:
                banned = conf.ban(o.hostname)
            except Exception:
                logger.exception("bad ban() func in dpark.conf")
                banned = False
            if banned:
                logger.debug("skip offer on banned node: %s", o.hostname)
                # An offer that is neither launched on nor declined stays
                # held by this framework until Mesos rescinds it, so
                # decline it explicitly.
                driver.declineOffer(o.id)
                continue

            group = self.getAttribute(o.attributes, 'group') or 'None'
            # With self.group set, only those groups are accepted; groups
            # starting with '_' are accepted only when listed explicitly.
            # `or ()` keeps an unset (None) self.group from raising.
            if ((self.group or group.startswith('_'))
                    and group not in (self.group or ())):
                driver.declineOffer(o.id,
                                    filters=Dict(refuse_seconds=0xFFFFFFFF))
                continue

            if self.task_host_manager.is_unhealthy_host(o.hostname):
                logger.warning('the host %s is unhealthy so skip it',
                               o.hostname)
                driver.declineOffer(o.id, filters=Dict(refuse_seconds=1800))
                continue

            self.task_host_manager.register_host(o.hostname)
            usable.append(o)
        offers = usable

        cpus = [self.getResource(o.resources, 'cpus') for o in offers]
        gpus = [self.getResource(o.resources, 'gpus') for o in offers]
        # hold back EXECUTOR_MEMORY on agents where none of our tasks are
        # currently counted
        mems = [
            self.getResource(o.resources, 'mem') -
            (EXECUTOR_MEMORY
             if o.agent_id.value not in self.agent_id_to_ttids else 0)
            for o in offers
        ]

        tasks = {}  # offer id -> list of TaskInfo to launch on that offer
        for taskset in self.active_tasksets.values():
            # keep offering until this taskset places nothing more
            while True:
                host_offers = {}
                for i, o in enumerate(offers):
                    if self.agent_id_to_ttids.get(o.agent_id.value,
                                                  0) >= self.task_per_node:
                        logger.debug('the task limit exceeded at host %s',
                                     o.hostname)
                        continue
                    # NOTE(review): uses the scheduler-wide self.mem/cpus,
                    # not the taskset's own request -- confirm intent
                    if (mems[i] < self.mem + EXECUTOR_MEMORY
                            or cpus[i] < self.cpus + EXECUTOR_CPUS):
                        continue
                    host_offers[o.hostname] = (i, o)
                assigned_list = taskset.taskOffer(host_offers, cpus, mems,
                                                  gpus)
                if not assigned_list:
                    break
                for i, o, t in assigned_list:
                    task = self.createTask(o, t)
                    tasks.setdefault(o.id.value, []).append(task)
                    logger.debug('dispatch %s into %s', t, o.hostname)
                    ttid = task.task_id.value
                    agent_id = o.agent_id.value
                    taskset.ttids.add(ttid)
                    self.ttid_to_agent_id[ttid] = agent_id
                    self.agent_id_to_ttids[agent_id] = (
                        self.agent_id_to_ttids.get(agent_id, 0) + 1)
                    cpus[i] -= min(cpus[i], t.cpus)
                    mems[i] -= t.mem
                    gpus[i] -= t.gpus

        used = time.time() - start
        if used > 10:
            logger.error('use too much time in resourceOffers: %.2fs', used)

        for o in offers:
            if o.id.value in tasks:
                driver.launchTasks(o.id, tasks[o.id.value])
            else:
                driver.declineOffer(o.id)

    @safe
    def offerRescinded(self, driver, offer_id):
        """Mesos callback: an offer was withdrawn; re-request if work remains."""
        logger.debug('rescinded offer: %s', offer_id)
        if not self.active_tasksets:
            return
        self.requestMoreResources()

    def getResource(self, res, name):
        """Return the scalar value of the first resource named ``name``.

        Returns 0.0 when no resource in ``res`` has that name.
        """
        matches = (r.scalar.value for r in res if r.name == name)
        return next(matches, 0.0)

    def getAttribute(self, attrs, name):
        """Return the text of the first attribute named ``name``, else None."""
        return next((a.text.value for a in attrs if a.name == name), None)

    def createTask(self, o, t):
        """Build the Mesos TaskInfo that runs task ``t`` on offer ``o``.

        ``(t, t.try_id)`` is pickled, compressed and encoded into ``data``;
        cpus/mem/gpus scalar resources are taken from ``t``.  Payloads above
        1000 KiB are logged; payloads of 50 MiB or more fail an assertion.
        """
        try_id = t.try_id
        task = Dict()
        task.name = 'task %s' % try_id
        task.task_id.value = try_id
        task.agent_id.value = o.agent_id.value
        payload = encode_data(compress(cPickle.dumps((t, try_id), -1)))
        task.data = payload
        task.executor = self.executor

        size = len(payload)
        if size > 1000 * 1024:
            logger.warning('task too large: %s %d', t, size)
        assert size < (50 << 20), 'Task data too large: %s' % (size,)

        task.resources = resources = []
        requested = (('cpus', t.cpus), ('mem', t.mem), ('gpus', t.gpus))
        for res_name, amount in requested:
            r = Dict()
            resources.append(r)
            r.name = res_name
            r.type = 'SCALAR'
            r.scalar.value = amount

        return task

    @safe
    def statusUpdate(self, driver, status):
        """Mesos callback: route a task status update to its TaskSet.

        Behavior by case:

        * The task's taskset is gone: a still-running task is killed, and any
          other update is ignored.
        * TASK_RUNNING: forwarded to the taskset.
        * Any other state for a task still tracked by the taskset: the task
          is dropped from the agent accounting.
        * TASK_FINISHED / TASK_FAILED carrying ``data``: the executor's
          pickled ``(reason, result, accUpdate, task_stats)`` is decoded
          before being forwarded.  If the result flag is >= 2, the payload is
          first fetched from the URL it holds.

        NOTE: ``cPickle.loads`` runs on data reported back through Mesos;
        this assumes executors are trusted cluster members.
        """
        from contextlib import closing

        def plot_progresses():
            # redraw one progress line per active taskset with ANSI escapes
            if not self.color:
                return
            total = len(self.active_tasksets)
            logger.info('\x1b[2K\x1b[J\x1b[1A')
            for i, taskset_id in enumerate(self.active_tasksets):
                ending = '\x1b[%sA' % total if i == total - 1 else ''
                self.active_tasksets[taskset_id].progress(ending)

        def fetch(url):
            # read a remotely stored result, retrying once on IOError; the
            # response is closed explicitly rather than left to the GC
            try:
                with closing(urllib.request.urlopen(url)) as resp:
                    return resp.read()
            except IOError:
                with closing(urllib.request.urlopen(url)) as resp:
                    return resp.read()

        mesos_task_id = status.task_id.value
        state = status.state
        reason = status.get('message')  # set by mesos
        data = status.get('data')

        logger.debug('status update: %s %s', mesos_task_id, state)

        ttid = TTID(mesos_task_id)
        taskset = self.active_tasksets.get(ttid.taskset_id)

        if taskset is None:
            if state == 'TASK_RUNNING':
                logger.debug('kill task %s as its taskset has gone',
                             mesos_task_id)
                self.driver.killTask(Dict(value=mesos_task_id))
            else:
                logger.debug('ignore task %s as its taskset has gone',
                             mesos_task_id)
            return

        if state == 'TASK_RUNNING':
            taskset.statusUpdate(ttid.task_id, ttid.task_try, state)
            if taskset.tasksFinished == 0:
                plot_progresses()
            return

        if mesos_task_id not in taskset.ttids:
            logger.debug(
                'ignore task %s as it has finished or failed, new msg: %s',
                mesos_task_id, (state, reason))
            return

        # terminal update for a tracked task: release its agent slot
        taskset.ttids.remove(mesos_task_id)
        if mesos_task_id in self.ttid_to_agent_id:
            agent_id = self.ttid_to_agent_id.pop(mesos_task_id)
            if agent_id in self.agent_id_to_ttids:
                self.agent_id_to_ttids[agent_id] -= 1

        if state in ('TASK_FINISHED', 'TASK_FAILED') and data:
            try:
                reason, result, accUpdate, task_stats = cPickle.loads(
                    decode_data(data))
                if result:
                    flag, data = result
                    if flag >= 2:
                        # payload holds a URL to the real result
                        data = fetch(data)
                        flag -= 2
                    data = decompress(data)
                    # flag 0 -> marshal-encoded, otherwise pickled
                    if flag == 0:
                        result = marshal.loads(data)
                    else:
                        result = cPickle.loads(data)
                taskset.statusUpdate(ttid.task_id, ttid.task_try,
                                     state, reason, result, accUpdate,
                                     task_stats)
                if state == 'TASK_FINISHED':
                    plot_progresses()
            except Exception as e:
                logger.warning(
                    'error when cPickle.loads(): %s, data:%s', e,
                    len(data))
                state = 'TASK_FAILED'
                taskset.statusUpdate(ttid.task_id, ttid.task_try,
                                     state, 'load failed: %s' % e)
        else:
            # killed, lost (or finished/failed without a payload)
            taskset.statusUpdate(ttid.task_id, ttid.task_try, state,
                                 reason or data)

    @safe
    def tasksetFinished(self, taskset):
        """Retire ``taskset``: kill its leftover tasks and drop its state.

        Every task id still in ``taskset.ttids`` is killed, and its agent
        accounting is released here.  Once the taskset leaves
        ``active_tasksets``, ``statusUpdate`` ignores later updates for its
        tasks, so it would never release those slots.  Without this, agents
        would keep counting the killed tasks against ``task_per_node`` for as
        long as other tasksets stay active.
        """
        logger.debug('taskset %s finished', taskset.id)
        if taskset.id not in self.active_tasksets:
            return

        self.last_finish_time = time.time()
        for mesos_task_id in taskset.ttids:
            self.driver.killTask(Dict(value=mesos_task_id))
            agent_id = self.ttid_to_agent_id.pop(mesos_task_id, None)
            if agent_id is not None and agent_id in self.agent_id_to_ttids:
                self.agent_id_to_ttids[agent_id] -= 1
        del self.active_tasksets[taskset.id]

        if not self.active_tasksets:
            self.agent_id_to_ttids.clear()
            self.ttid_to_agent_id.clear()

    @safe
    def error(self, driver, message):
        """Mesos callback for a driver error: log it, then raise RuntimeError."""
        logger.error('Mesos error message: %s', message)
        failure = RuntimeError(message)
        raise failure

    # @safe
    def stop(self):
        """Stop the driver and the log receivers; no-op if not started."""
        if self.started:
            logger.debug('stop scheduler')
            self.started = False
            self.isRegistered = False

            driver = self.driver
            driver.stop(False)
            driver.join()
            self.driver = None

            for receiver in (self.out_logger, self.err_logger):
                receiver.stop()

    def defaultParallelism(self):
        """Return the fixed default parallelism (16)."""
        parallelism = 16
        return parallelism

    def frameworkMessage(self, driver, executor_id, agent_id, data):
        """Mesos callback: log a framework message tagged with its agent."""
        agent = agent_id.value
        logger.warning('[agent %s] %s', agent, data)

    def executorLost(self, driver, executor_id, agent_id, status):
        """Mesos callback: an executor was lost; forget its agent's count."""
        agent = agent_id.value
        logger.warning('executor at %s %s lost: %s', agent,
                       executor_id.value, status)
        self.agent_id_to_ttids.pop(agent, None)

    def slaveLost(self, driver, agent_id):
        """Mesos callback: an agent was lost; forget its task count."""
        agent = agent_id.value
        logger.warning('agent %s lost', agent)
        self.agent_id_to_ttids.pop(agent, None)

    def killTask(self, task_id, num_try):
        """Ask the driver to kill attempt ``num_try`` of task ``task_id``."""
        mesos_task_id = TTID.make_ttid(task_id, num_try)
        self.driver.killTask(Dict(value=mesos_task_id))