Example #1
0
    def setUp(self):
        """Wire up a client logger that ships records to a local socket
        log server, and start that server on a background thread."""
        # Client side: records go to a socket handler aimed at localhost.
        self.client_logger = get_logger(name='cola_test_client',
                                        server='localhost')
        # Server side: the logger onto which received records are replayed.
        self.server_logger = get_logger(name='cola_test_server')

        # serve_forever blocks, so run the receiver on its own thread.
        self.log_server = LogRecordSocketReceiver(logger=self.server_logger)
        threading.Thread(target=self.log_server.serve_forever).start()
Example #2
0
class Test(unittest.TestCase):
    """Round-trip test of the socket logging pipeline: records emitted on
    the client logger should be delivered to the socket receiver and
    replayed on the server logger."""

    def setUp(self):
        # Client side ships records to localhost; server side receives.
        self.client_logger = get_logger(name='cola_test_client',
                                        server='localhost')
        self.server_logger = get_logger(name='cola_test_server')

        self.log_server = LogRecordSocketReceiver(logger=self.server_logger)
        # Fix: keep a handle on the serving thread so tearDown can join
        # it; previously the thread was unreferenced and could linger
        # across tests, racing the next setUp for the listen port.
        self._server_thread = threading.Thread(
            target=self.log_server.serve_forever)
        self._server_thread.start()

    def tearDown(self):
        self.log_server.shutdown()
        # Bounded wait for the serve loop to actually exit.
        self._server_thread.join(5)

    def testLog(self):
        # Emit one record per level; delivery is observed on the server
        # logger's configured output.
        self.client_logger.error('Sth happens here')
        self.client_logger.info('sth info here')
Example #3
0
File: test_log.py — Project: 0pengl/cola
class Test(unittest.TestCase):
    """End-to-end check of the socket log handler/receiver pair."""

    def setUp(self):
        # A logger whose records are shipped to a server on localhost.
        self.client_logger = get_logger(name='cola_test_client',
                                        server='localhost')
        # The logger onto which the receiver replays incoming records.
        self.server_logger = get_logger(name='cola_test_server')

        receiver = LogRecordSocketReceiver(logger=self.server_logger)
        self.log_server = receiver
        threading.Thread(target=receiver.serve_forever).start()

    def tearDown(self):
        # Stop the serve_forever loop started in setUp.
        self.log_server.shutdown()

    def testLog(self):
        # One record at each of two levels exercises the pipeline.
        self.client_logger.error('Sth happens here')
        self.client_logger.info('sth info here')
Example #4
0
def start_log_server():
    """Start the module-level log record receiver, at most once.

    Binds a ``LogRecordSocketReceiver`` to this host's IP on
    ``log_server_port`` and serves it on a background thread.  Calls made
    while ``log_server`` is already set are no-ops.
    """
    # Only ``log_server`` is assigned; ``log_server_port`` is merely read,
    # so its ``global`` declaration was unnecessary and has been dropped.
    global log_server

    if log_server is not None:
        return
    # NOTE(review): not thread-safe — two concurrent callers could both
    # pass the None check and create two receivers; confirm callers are
    # single-threaded or add a lock.
    log_server = LogRecordSocketReceiver(logger=logger, host=get_ip(),
                                         port=log_server_port)
    threading.Thread(target=log_server.serve_forever).start()
Example #5
0
File: master.py — Project: laocheng/cola
 def _init_log_server(self, logger):
     """Start a socket receiver that funnels remote log records into
     *logger*, serving on a background thread kept in ``self.log_t``."""
     # Fix: honor the ``logger`` parameter instead of silently using
     # self.logger (callers pass self.logger, so behavior is unchanged
     # for them, but the signature is no longer misleading).
     self.log_server = LogRecordSocketReceiver(host=self.ctx.ip,
                                               logger=logger)
     self.log_t = threading.Thread(target=self.log_server.serve_forever)
     self.log_t.start()
Example #6
0
File: master.py — Project: laocheng/cola
class Master(object):
    """Cluster master.

    Owns the master-side RPC endpoint, tracks worker liveness through
    heartbeats, ships packed job archives to workers, and drives each
    job through its prepare / run_job / stop_job / clear_job stages.
    """

    def __init__(self, ctx):
        self.ctx = ctx
        self.rpc_server = self.ctx.master_rpc_server
        assert self.rpc_server is not None

        # Layout: zip/ holds uploaded job archives, jobs/ the unpacked jobs.
        self.working_dir = os.path.join(self.ctx.working_dir, 'master')
        self.zip_dir = os.path.join(self.working_dir, 'zip')
        self.job_dir = os.path.join(self.working_dir, 'jobs')
        if not os.path.exists(self.zip_dir):
            os.makedirs(self.zip_dir)
        if not os.path.exists(self.job_dir):
            os.makedirs(self.job_dir)

        self.worker_tracker = WorkerTracker()
        self.job_tracker = JobTracker()

        # Workers considered dead; skipped until they heartbeat steadily.
        self.black_list = []

        self.stopped = threading.Event()

        self.logger = get_logger("cola_master")
        self._init_log_server(self.logger)

        self._register_rpc()
        self.load()
        FileTransportServer(self.rpc_server, self.zip_dir)

    def load(self):
        """Restore metadata of previously run jobs (best effort)."""
        self.runned_job_metas = {}

        job_meta_file = os.path.join(self.working_dir,
                                     JOB_META_STATUS_FILENAME)
        if os.path.exists(job_meta_file) and \
            os.path.getsize(job_meta_file) > 0:
            try:
                # Fix: pickle data is not text -- open in binary mode.
                with open(job_meta_file, 'rb') as f:
                    self.runned_job_metas = pickle.load(f)
            except Exception:
                # A corrupt meta file only loses history; start empty.
                pass

    def save(self):
        """Persist run-job metadata for the next master start-up."""
        job_meta_file = os.path.join(self.working_dir,
                                     JOB_META_STATUS_FILENAME)
        # Fix: pickle data is not text -- open in binary mode.
        with open(job_meta_file, 'wb') as f:
            pickle.dump(self.runned_job_metas, f)

    def _register_rpc(self):
        # Expose the public master operations over RPC.
        self.rpc_server.register_function(self.run_job, 'run_job')
        self.rpc_server.register_function(self.stop_job, 'stop_job')
        self.rpc_server.register_function(self.pack_job_error,
                                          'pack_job_error')
        self.rpc_server.register_function(self.list_runnable_jobs,
                                          'runnable_jobs')
        self.rpc_server.register_function(
            lambda: self.job_tracker.running_jobs.keys(), 'running_jobs')
        self.rpc_server.register_function(self.list_workers, 'list_workers')
        self.rpc_server.register_function(self.shutdown, 'shutdown')
        self.rpc_server.register_function(self.register_heartbeat,
                                          'register_heartbeat')

    def register_heartbeat(self, worker):
        """Record a heartbeat from *worker*; return all known workers."""
        self.worker_tracker.register_worker(worker)
        return self.worker_tracker.workers.keys()

    def _init_log_server(self, logger):
        """Serve a socket receiver that feeds remote records to *logger*."""
        # Fix: honor the parameter (was hard-wired to self.logger, which
        # the only caller passes anyway).
        self.log_server = LogRecordSocketReceiver(host=self.ctx.ip,
                                                  logger=logger)
        self.log_t = threading.Thread(target=self.log_server.serve_forever)
        self.log_t.start()

    def _shutdown_log_server(self):
        if hasattr(self, 'log_server'):
            self.log_server.shutdown()
            self.log_t.join()

    def _check_workers(self):
        """Heartbeat watchdog loop; runs until ``self.stopped`` is set."""
        while not self.stopped.is_set():
            for worker, info in self.worker_tracker.workers.iteritems():
                # if loose connection
                if int(time.time()) - info.last_update \
                    > HEARTBEAT_CHECK_INTERVAL:

                    info.continous_register = 0
                    # Degrade RUNNING -> HANGUP -> STOPPED on successive
                    # missed heartbeats; STOPPED workers are blacklisted
                    # and withdrawn from every running job.
                    if info.status == RUNNING:
                        info.status = HANGUP
                    elif info.status == HANGUP:
                        info.status = STOPPED
                        self.black_list.append(worker)

                        for job in self.job_tracker.running_jobs:
                            self.job_tracker.remove_worker(job, worker)

                # if continously connect for more than 10 min
                elif info.continous_register >= CONTINOUS_HEARTBEAT:
                    if info.status != RUNNING:
                        info.status = RUNNING
                    if worker in self.black_list:
                        self.black_list.remove(worker)

                    # Late-join the recovered worker into running jobs.
                    for job in self.job_tracker.running_jobs:
                        if not client_call(worker, 'has_job'):
                            client_call(worker, 'prepare', job)
                            client_call(worker, 'run_job', job)
                        self.job_tracker.add_worker(job, worker)

            self.stopped.wait(HEARTBEAT_CHECK_INTERVAL)

    def _check_jobs(self):
        """Reap jobs whose budget server reports ALLFINISHED."""
        while not self.stopped.is_set():
            for job_master in self.job_tracker.running_jobs.values():
                if job_master.budget_server.get_status() == ALLFINISHED:
                    self.stop_job(job_master.job_name)
                    self.job_tracker.remove_job(job_master.job_name)
            self.stopped.wait(JOB_CHECK_INTERVAL)

    def _unzip(self, job_name):
        """Unpack the uploaded archive of *job_name* into jobs/."""
        zip_file = os.path.join(self.zip_dir, job_name + '.zip')
        if os.path.exists(zip_file):
            ZipHandler.uncompress(zip_file, self.job_dir)

    def _register_runned_job(self, job_name, job_desc):
        self.runned_job_metas[job_name] = {
            'job_name': job_desc.name,
            'created': time.time()
        }

    def run(self):
        """Start the worker and job watchdog threads."""
        self._worker_t = threading.Thread(target=self._check_workers)
        self._worker_t.start()

        self._job_t = threading.Thread(target=self._check_jobs)
        self._job_t.start()

    def run_job(self, job_name, unzip=False, wait_for_workers=False):
        """Distribute *job_name* to the known workers and start it.

        :param unzip: unpack the uploaded archive before importing.
        :param wait_for_workers: poll (3s interval) until at least one
            worker has registered, aborting if the master is stopped.
        """
        if wait_for_workers:
            while not self.stopped.is_set():
                if len(self.worker_tracker.workers) > 0:
                    break
                stopped = self.stopped.wait(3)
                if stopped:
                    return

        if unzip:
            self._unzip(job_name)

        job_path = os.path.join(self.job_dir, job_name)
        job_desc = import_job_desc(job_path)
        job_master = JobMaster(self.ctx, job_name, job_desc,
                               self.worker_tracker.workers.keys())
        job_master.init()
        self.job_tracker.register_job(job_name, job_master)
        self._register_runned_job(job_name, job_desc)

        # Push the job archive to every participating worker.
        zip_file = os.path.join(self.zip_dir, job_name + '.zip')
        for worker in job_master.workers:
            FileTransportClient(worker, zip_file).send_file()

        self.logger.debug('entering the master prepare stage, job id: %s' %
                          job_name)
        self.logger.debug('job available workers: %s' % job_master.workers)
        stage = Stage(job_master.workers, 'prepare')
        prepared_ok = stage.barrier(True, job_name)
        if not prepared_ok:
            self.logger.error("prepare for running failed")
            return

        self.logger.debug('entering the master run_job stage, job id: %s' %
                          job_name)
        stage = Stage(job_master.workers, 'run_job')
        run_ok = stage.barrier(True, job_name)
        if not run_ok:
            self.logger.error("run job failed, job id: %s" % job_name)

    def stop_job(self, job_name):
        """Stop and clear *job_name* on all of its workers."""
        job_master = self.job_tracker.get_job_master(job_name)
        stage = Stage(job_master.workers, 'stop_job')
        stage.barrier(True, job_name)

        stage = Stage(job_master.workers, 'clear_job')
        stage.barrier(True, job_name)

        self.job_tracker.remove_job(job_name)

        self.logger.debug('stop job: %s' % job_name)

    def pack_job_error(self, job_name):
        """Collect per-worker error zips into one archive; return its path."""
        job_master = self.job_tracker.get_job_master(job_name)
        stage = Stage(job_master.workers, 'pack_job_error')
        stage.barrier(True, job_name)

        error_dir = os.path.join(self.working_dir, 'errors')
        if not os.path.exists(error_dir):
            os.makedirs(error_dir)
        error_filename = os.path.join(error_dir, '%s_errors.zip' % job_name)

        # Move the matching zips into a temp dir, compress them into one
        # archive, and always clean the temp dir up.
        suffix = '%s_errors.zip' % job_name
        temp_dir = tempfile.mkdtemp()
        try:
            for name in os.listdir(self.zip_dir):
                if name.endswith(suffix):
                    shutil.move(os.path.join(self.zip_dir, name), temp_dir)
            ZipHandler.compress(error_filename, temp_dir)
        finally:
            shutil.rmtree(temp_dir)

        return error_filename

    def list_runnable_jobs(self):
        """Map unpacked job directory names to their descriptions' names."""
        job_dirs = filter(
            lambda s: os.path.isdir(os.path.join(self.job_dir, s)),
            os.listdir(self.job_dir))

        jobs = {}
        for job_dir in job_dirs:
            desc = import_job_desc(os.path.join(self.job_dir, job_dir))
            jobs[job_dir] = desc.name
        return jobs

    def has_running_jobs(self):
        return len(self.job_tracker.running_jobs) > 0

    def list_workers(self):
        """Return (worker, status-name) pairs for every known worker."""
        return [(worker, STATUSES[worker_info.status]) for worker, worker_info \
                in self.worker_tracker.workers.iteritems()]

    def _stop_all_jobs(self):
        for job_name in self.job_tracker.running_jobs.keys():
            self.stop_job(job_name)

    def _shutdown_all_workers(self):
        stage = Stage(self.worker_tracker.workers.keys(), 'shutdown')
        stage.barrier(True)

    def shutdown(self):
        """Stop jobs, workers, watchdogs and servers, then persist state."""
        # run() was never called: nothing to shut down.
        if not hasattr(self, '_worker_t'):
            return
        if not hasattr(self, '_job_t'):
            return

        self.logger.debug('master starts to shutdown')

        self.stopped.set()
        self._stop_all_jobs()
        self._shutdown_all_workers()

        self._worker_t.join()
        self._job_t.join()

        self.save()
        self.rpc_server.shutdown()
        self.logger.debug('master shutdown finished')
        # Log receiver goes down last, after the final debug lines.
        self._shutdown_log_server()
Example #7
0
File: loader.py — Project: Ganer/cola
class MasterJobLoader(LimitionJobLoader, JobLoader):
    """Job loader on the master side.

    Coordinates a set of worker nodes over RPC: waits for all of them to
    register, seeds the message queue with the job's start URLs, starts
    the run, and tears everything down once every node reports finished.
    """

    def __init__(self, job, data_dir, nodes, local_ip=None, client=None,
                 context=None, copies=1, force=False):
        ctx = context or job.context
        master_port = ctx.job.master_port
        # Resolve the address this master binds to; an explicitly given
        # IP must belong to this machine.
        if local_ip is None:
            local_ip = get_ip()
        else:
            choices_ips = get_ips()
            if local_ip not in choices_ips:
                raise ValueError('IP address must be one of (%s)' % ','.join(choices_ips))
        local = '%s:%s' % (local_ip, master_port)

        JobLoader.__init__(self, job, data_dir, local,
                           context=ctx, copies=copies, force=force)
        LimitionJobLoader.__init__(self, job, context=ctx)

        # Refuse to start if another job master is already running.
        self.check()

        self.nodes = nodes
        self.not_registered = self.nodes[:]
        self.not_finished = self.nodes[:]

        # mq
        self.mq_client = MessageQueueClient(self.nodes, copies=copies)

        # Locks used as one-shot events: acquired here and released by
        # ready()/worker_finish() once every node has checked in, which
        # unblocks the corresponding acquire in run().
        self.ready_lock = threading.Lock()
        self.ready_lock.acquire()
        self.finish_lock = threading.Lock()
        self.finish_lock.acquire()

        # logger
        self.logger = get_logger(
            name='cola_master_%s'%self.job.real_name,
            filename=os.path.join(self.root, 'job.log'),
            is_master=True)
        self.client = client
        self.client_handler = None
        if self.client is not None:
            self.client_handler = add_log_client(self.logger, self.client)

        self.init_rpc_server()
        self.init_rate_clear()
        self.init_logger_server(self.logger)

        # register rpc server
        self.rpc_server.register_function(self.client_stop, 'client_stop')
        self.rpc_server.register_function(self.ready, 'ready')
        self.rpc_server.register_function(self.worker_finish, 'worker_finish')
        self.rpc_server.register_function(self.complete, 'complete')
        self.rpc_server.register_function(self.error, 'error')
        self.rpc_server.register_function(self.get_nodes, 'get_nodes')
        self.rpc_server.register_function(self.apply, 'apply')
        self.rpc_server.register_function(self.require, 'require')
        self.rpc_server.register_function(self.stop, 'stop')
        self.rpc_server.register_function(self.add_node, 'add_node')
        self.rpc_server.register_function(self.remove_node, 'remove_node')

        # Shut down cleanly on Ctrl-C / termination.
        signal.signal(signal.SIGINT, self.signal_handler)
        signal.signal(signal.SIGTERM, self.signal_handler)

    def init_logger_server(self, logger):
        """Serve a socket receiver feeding node records into *logger*."""
        self.log_server = LogRecordSocketReceiver(host=get_ip(), logger=logger)
        threading.Thread(target=self.log_server.serve_forever).start()

    def stop_logger_server(self):
        if hasattr(self, 'log_server'):
            self.log_server.shutdown()
            # NOTE(review): shutdown() ends serve_forever; stop() looks
            # like a repo-specific extra teardown -- confirm it exists on
            # LogRecordSocketReceiver.
            self.log_server.stop()

    def client_stop(self):
        """Detach the client log handler (RPC: 'client_stop')."""
        if self.client_handler is not None:
            self.logger.removeHandler(self.client_handler)

    def check(self):
        """Raise JobMasterRunning unless this master may run."""
        env_legal = self.check_env(force=self.force)
        if not env_legal:
            raise JobMasterRunning('There has been a running job master.')

    def release_lock(self, lock):
        """Release *lock*, ignoring the error raised if already unlocked."""
        try:
            lock.release()
        except Exception:
            pass

    def finish(self):
        """Tear everything down: locks, loaders, log server, client."""
        self.release_lock(self.ready_lock)
        self.release_lock(self.finish_lock)

        LimitionJobLoader.finish(self)
        JobLoader.finish(self)
        self.stop_logger_server()

        # Best effort: a handler failing to close must not block teardown.
        try:
            for handler in self.logger.handlers:
                handler.close()
        except Exception:
            pass

        if self.client is not None:
            rpc_client = '%s:%s' % (
                self.client.split(':')[0],
                main_conf.client.port
            )
            client_call(rpc_client, 'stop', ignore=True)

        self.stopped = True

    def stop(self):
        """Ask every node to stop, then finish locally (RPC: 'stop')."""
        for node in self.nodes:
            try:
                client_call(node, 'stop')
            except socket.error:
                pass
        self.finish()

    def signal_handler(self, signum, frame):
        self.stop()

    def get_nodes(self):
        return self.nodes

    def ready(self, node):
        """RPC callback: *node* registered; unblock run() when all have."""
        if node in self.not_registered:
            self.not_registered.remove(node)
            if len(self.not_registered) == 0:
                self.ready_lock.release()

    def worker_finish(self, node):
        """RPC callback: *node* finished; unblock run() when all have."""
        if node in self.not_finished:
            self.not_finished.remove(node)
            if len(self.not_finished) == 0:
                self.finish_lock.release()

    def add_node(self, node):
        """Announce *node* to the existing nodes, then add and start it.

        Fix: the loop previously reused ``node`` as its loop variable,
        clobbering the argument -- peers were told about themselves and
        the wrong node was appended and started.
        """
        for peer in self.nodes:
            client_call(peer, 'add_node', node)
        self.nodes.append(node)
        client_call(node, 'run')

    def remove_node(self, node):
        """Announce removal of *node* to the others, then drop it.

        Fix: same loop-variable shadowing as in add_node.
        """
        for peer in self.nodes:
            client_call(peer, 'remove_node', node)
        self.nodes.remove(node)

    def run(self):
        """Block until all nodes register, start them, block until done."""
        self.ready_lock.acquire()

        if not self.stopped and len(self.not_registered) == 0:
            self.mq_client.put(self.job.starts)
            for node in self.nodes:
                client_call(node, 'run')

        self.finish_lock.acquire()

        # Report completion to the master watcher, if reachable.
        try:
            master_watcher = '%s:%s' % (get_ip(), main_conf.master.port)
            client_call(master_watcher, 'finish_job', self.job.real_name)
        except socket.error:
            pass

    def __enter__(self):
        return self

    def __exit__(self, type_, value, traceback):
        self.finish()
Example #8
0
File: loader.py — Project: Ganer/cola
 def init_logger_server(self, logger):
     """Start a socket log receiver for *logger* on this host's IP,
     serving on a background thread."""
     self.log_server = LogRecordSocketReceiver(host=get_ip(), logger=logger)
     # serve_forever blocks, hence the dedicated thread.
     threading.Thread(target=self.log_server.serve_forever).start()
Example #9
0
File: master.py — Project: awai0707/cola
 def _init_log_server(self, logger):
     """Start a socket receiver feeding remote log records into the
     master logger; the serving thread is kept in ``self.log_t``."""
     # NOTE(review): the ``logger`` parameter is ignored in favor of
     # self.logger; callers pass self.logger so behavior matches, but
     # the signature is misleading -- consider using the parameter.
     self.log_server = LogRecordSocketReceiver(host=self.ctx.ip, 
                                               logger=self.logger)
     self.log_t = threading.Thread(target=self.log_server.serve_forever)
     self.log_t.start()
Example #10
0
File: master.py — Project: awai0707/cola
class Master(object):
    """Cluster master.

    Owns the master-side RPC endpoint, tracks worker liveness through
    heartbeats, ships packed job archives to workers, and drives each
    job through its prepare / run_job / stop_job / clear_job stages.
    """

    def __init__(self, ctx):
        self.ctx = ctx
        self.rpc_server = self.ctx.master_rpc_server
        assert self.rpc_server is not None

        # Layout: zip/ holds uploaded job archives, jobs/ the unpacked jobs.
        self.working_dir = os.path.join(self.ctx.working_dir, 'master')
        self.zip_dir = os.path.join(self.working_dir, 'zip')
        self.job_dir = os.path.join(self.working_dir, 'jobs')
        if not os.path.exists(self.zip_dir):
            os.makedirs(self.zip_dir)
        if not os.path.exists(self.job_dir):
            os.makedirs(self.job_dir)

        self.worker_tracker = WorkerTracker()
        self.job_tracker = JobTracker()

        # Workers considered dead; skipped until they heartbeat steadily.
        self.black_list = []

        self.stopped = threading.Event()

        self.logger = get_logger("cola_master")
        self._init_log_server(self.logger)

        self._register_rpc()
        self.load()
        FileTransportServer(self.rpc_server, self.zip_dir)

    def load(self):
        """Restore metadata of previously run jobs (best effort)."""
        self.runned_job_metas = {}

        job_meta_file = os.path.join(self.working_dir, JOB_META_STATUS_FILENAME)
        if os.path.exists(job_meta_file) and \
            os.path.getsize(job_meta_file) > 0:
            try:
                # Fix: pickle data is not text -- open in binary mode.
                with open(job_meta_file, 'rb') as f:
                    self.runned_job_metas = pickle.load(f)
            except Exception:
                # A corrupt meta file only loses history; start empty.
                pass

    def save(self):
        """Persist run-job metadata for the next master start-up."""
        job_meta_file = os.path.join(self.working_dir, JOB_META_STATUS_FILENAME)
        # Fix: pickle data is not text -- open in binary mode.
        with open(job_meta_file, 'wb') as f:
            pickle.dump(self.runned_job_metas, f)

    def _register_rpc(self):
        # Expose the public master operations over RPC.
        self.rpc_server.register_function(self.run_job, 'run_job')
        self.rpc_server.register_function(self.stop_job, 'stop_job')
        self.rpc_server.register_function(self.pack_job_error, 'pack_job_error')
        self.rpc_server.register_function(self.list_runnable_jobs,
                                          'runnable_jobs')
        self.rpc_server.register_function(lambda: self.job_tracker.running_jobs.keys(),
                                          'running_jobs')
        self.rpc_server.register_function(self.list_workers,
                                          'list_workers')
        self.rpc_server.register_function(self.shutdown, 'shutdown')
        self.rpc_server.register_function(self.register_heartbeat,
                                          'register_heartbeat')

    def register_heartbeat(self, worker):
        """Record a heartbeat from *worker*; return all known workers."""
        self.worker_tracker.register_worker(worker)
        return self.worker_tracker.workers.keys()

    def _init_log_server(self, logger):
        """Serve a socket receiver that feeds remote records to *logger*."""
        # Fix: honor the parameter (was hard-wired to self.logger, which
        # the only caller passes anyway).
        self.log_server = LogRecordSocketReceiver(host=self.ctx.ip,
                                                  logger=logger)
        self.log_t = threading.Thread(target=self.log_server.serve_forever)
        self.log_t.start()

    def _shutdown_log_server(self):
        if hasattr(self, 'log_server'):
            self.log_server.shutdown()
            self.log_t.join()

    def _check_workers(self):
        """Heartbeat watchdog loop; runs until ``self.stopped`` is set."""
        while not self.stopped.is_set():
            for worker, info in self.worker_tracker.workers.iteritems():
                # if loose connection
                if int(time.time()) - info.last_update \
                    > HEARTBEAT_CHECK_INTERVAL:

                    info.continous_register = 0
                    # Degrade RUNNING -> HANGUP -> STOPPED on successive
                    # missed heartbeats; STOPPED workers are blacklisted
                    # and withdrawn from every running job.
                    if info.status == RUNNING:
                        info.status = HANGUP
                    elif info.status == HANGUP:
                        info.status = STOPPED
                        self.black_list.append(worker)

                        for job in self.job_tracker.running_jobs:
                            self.job_tracker.remove_worker(job, worker)

                # if continously connect for more than 10 min
                elif info.continous_register >= CONTINOUS_HEARTBEAT:
                    if info.status != RUNNING:
                        info.status = RUNNING
                    if worker in self.black_list:
                        self.black_list.remove(worker)

                    # Late-join the recovered worker into running jobs.
                    for job in self.job_tracker.running_jobs:
                        if not client_call(worker, 'has_job'):
                            client_call(worker, 'prepare', job)
                            client_call(worker, 'run_job', job)
                        self.job_tracker.add_worker(job, worker)

            self.stopped.wait(HEARTBEAT_CHECK_INTERVAL)

    def _check_jobs(self):
        """Reap jobs whose budget server reports ALLFINISHED."""
        while not self.stopped.is_set():
            for job_master in self.job_tracker.running_jobs.values():
                if job_master.budget_server.get_status() == ALLFINISHED:
                    self.stop_job(job_master.job_name)
                    self.job_tracker.remove_job(job_master.job_name)
            self.stopped.wait(JOB_CHECK_INTERVAL)

    def _unzip(self, job_name):
        """Unpack the uploaded archive of *job_name* into jobs/."""
        zip_file = os.path.join(self.zip_dir, job_name+'.zip')
        if os.path.exists(zip_file):
            ZipHandler.uncompress(zip_file, self.job_dir)

    def _register_runned_job(self, job_name, job_desc):
        self.runned_job_metas[job_name] = {'job_name': job_desc.name,
                                           'created': time.time()}

    def run(self):
        """Start the worker and job watchdog threads."""
        self._worker_t = threading.Thread(target=self._check_workers)
        self._worker_t.start()

        self._job_t = threading.Thread(target=self._check_jobs)
        self._job_t.start()

    def run_job(self, job_name, unzip=False,
                wait_for_workers=False):
        """Distribute *job_name* to the known workers and start it.

        :param unzip: unpack the uploaded archive before importing.
        :param wait_for_workers: poll (3s interval) until at least one
            worker has registered, aborting if the master is stopped.
        """
        if wait_for_workers:
            while not self.stopped.is_set():
                if len(self.worker_tracker.workers) > 0:
                    break
                stopped = self.stopped.wait(3)
                if stopped:
                    return

        if unzip:
            self._unzip(job_name)

        job_path = os.path.join(self.job_dir, job_name)
        job_desc = import_job_desc(job_path)
        job_master = JobMaster(self.ctx, job_name, job_desc,
                               self.worker_tracker.workers.keys())
        job_master.init()
        self.job_tracker.register_job(job_name, job_master)
        self._register_runned_job(job_name, job_desc)

        # Push the job archive to every participating worker.
        zip_file = os.path.join(self.zip_dir, job_name+'.zip')
        for worker in job_master.workers:
            FileTransportClient(worker, zip_file).send_file()

        self.logger.debug(
            'entering the master prepare stage, job id: %s' % job_name)
        self.logger.debug(
            'job available workers: %s' % job_master.workers)
        stage = Stage(job_master.workers, 'prepare')
        # Fix: barrier results were ignored; bail out if prepare failed
        # instead of starting a job no worker prepared.
        prepared_ok = stage.barrier(True, job_name)
        if not prepared_ok:
            self.logger.error("prepare for running failed")
            return

        self.logger.debug(
            'entering the master run_job stage, job id: %s' % job_name)
        stage = Stage(job_master.workers, 'run_job')
        run_ok = stage.barrier(True, job_name)
        if not run_ok:
            self.logger.error("run job failed, job id: %s" % job_name)

    def stop_job(self, job_name):
        """Stop and clear *job_name* on all of its workers."""
        job_master = self.job_tracker.get_job_master(job_name)
        stage = Stage(job_master.workers, 'stop_job')
        stage.barrier(True, job_name)

        stage = Stage(job_master.workers, 'clear_job')
        stage.barrier(True, job_name)

        self.job_tracker.remove_job(job_name)

        self.logger.debug('stop job: %s' % job_name)

    def pack_job_error(self, job_name):
        """Collect per-worker error zips into one archive; return its path."""
        job_master = self.job_tracker.get_job_master(job_name)
        stage = Stage(job_master.workers, 'pack_job_error')
        stage.barrier(True, job_name)

        error_dir = os.path.join(self.working_dir, 'errors')
        if not os.path.exists(error_dir):
            os.makedirs(error_dir)
        error_filename = os.path.join(error_dir, '%s_errors.zip'%job_name)

        # Move the matching zips into a temp dir, compress them into one
        # archive, and always clean the temp dir up.
        suffix = '%s_errors.zip' % job_name
        temp_dir = tempfile.mkdtemp()
        try:
            for name in os.listdir(self.zip_dir):
                if name.endswith(suffix):
                    shutil.move(os.path.join(self.zip_dir, name), temp_dir)
            ZipHandler.compress(error_filename, temp_dir)
        finally:
            shutil.rmtree(temp_dir)

        return error_filename

    def list_runnable_jobs(self):
        """Map unpacked job directory names to their descriptions' names."""
        job_dirs = filter(lambda s: os.path.isdir(os.path.join(self.job_dir, s)),
                          os.listdir(self.job_dir))

        jobs = {}
        for job_dir in job_dirs:
            desc = import_job_desc(os.path.join(self.job_dir, job_dir))
            jobs[job_dir] = desc.name
        return jobs

    def has_running_jobs(self):
        return len(self.job_tracker.running_jobs) > 0

    def list_workers(self):
        """Return (worker, status-name) pairs for every known worker."""
        return [(worker, STATUSES[worker_info.status]) for worker, worker_info \
                in self.worker_tracker.workers.iteritems()]

    def _stop_all_jobs(self):
        for job_name in self.job_tracker.running_jobs.keys():
            self.stop_job(job_name)

    def _shutdown_all_workers(self):
        stage = Stage(self.worker_tracker.workers.keys(), 'shutdown')
        stage.barrier(True)

    def shutdown(self):
        """Stop jobs, workers, watchdogs and servers, then persist state."""
        # run() was never called: nothing to shut down.
        if not hasattr(self, '_worker_t'):
            return
        if not hasattr(self, '_job_t'):
            return

        self.logger.debug('master starts to shutdown')

        self.stopped.set()
        self._stop_all_jobs()
        self._shutdown_all_workers()

        self._worker_t.join()
        self._job_t.join()

        self.save()
        self.rpc_server.shutdown()
        self.logger.debug('master shutdown finished')
        # Log receiver goes down last, after the final debug lines.
        self._shutdown_log_server()
Example #11
0
File: test_log.py — Project: 0pengl/cola
 def setUp(self):
     """Wire up a client logger that ships records to a local socket
     log server, and start that server on a background thread."""
     # Client side sends to localhost; server side receives and replays.
     self.client_logger = get_logger(name='cola_test_client', server='localhost')
     self.server_logger = get_logger(name='cola_test_server')

     # serve_forever blocks, so run the receiver on its own thread.
     self.log_server = LogRecordSocketReceiver(logger=self.server_logger)
     threading.Thread(target=self.log_server.serve_forever).start()
Example #12
0
class MasterJobLoader(LimitionJobLoader, JobLoader):
    """Job loader on the master side.

    Coordinates a set of worker nodes over RPC: waits for all of them to
    register, seeds the message queue with the job's start URLs, starts
    the run, and tears everything down once every node reports finished.
    """

    def __init__(self,
                 job,
                 data_dir,
                 nodes,
                 local_ip=None,
                 client=None,
                 context=None,
                 copies=1,
                 force=False):
        ctx = context or job.context
        master_port = ctx.job.master_port
        # Resolve the address this master binds to; an explicitly given
        # IP must belong to this machine.
        if local_ip is None:
            local_ip = get_ip()
        else:
            choices_ips = get_ips()
            if local_ip not in choices_ips:
                raise ValueError('IP address must be one of (%s)' %
                                 ','.join(choices_ips))
        local = '%s:%s' % (local_ip, master_port)

        JobLoader.__init__(self,
                           job,
                           data_dir,
                           local,
                           context=ctx,
                           copies=copies,
                           force=force)
        LimitionJobLoader.__init__(self, job, context=ctx)

        # Refuse to start if another job master is already running.
        self.check()

        self.nodes = nodes
        self.not_registered = self.nodes[:]
        self.not_finished = self.nodes[:]

        # mq
        self.mq_client = MessageQueueClient(self.nodes, copies=copies)

        # Locks used as one-shot events: acquired here and released by
        # ready()/worker_finish() once every node has checked in, which
        # unblocks the corresponding acquire in run().
        self.ready_lock = threading.Lock()
        self.ready_lock.acquire()
        self.finish_lock = threading.Lock()
        self.finish_lock.acquire()

        # logger
        self.logger = get_logger(name='cola_master_%s' % self.job.real_name,
                                 filename=os.path.join(self.root, 'job.log'),
                                 is_master=True)
        self.client = client
        self.client_handler = None
        if self.client is not None:
            self.client_handler = add_log_client(self.logger, self.client)

        self.init_rpc_server()
        self.init_rate_clear()
        self.init_logger_server(self.logger)

        # register rpc server
        self.rpc_server.register_function(self.client_stop, 'client_stop')
        self.rpc_server.register_function(self.ready, 'ready')
        self.rpc_server.register_function(self.worker_finish, 'worker_finish')
        self.rpc_server.register_function(self.complete, 'complete')
        self.rpc_server.register_function(self.error, 'error')
        self.rpc_server.register_function(self.get_nodes, 'get_nodes')
        self.rpc_server.register_function(self.apply, 'apply')
        self.rpc_server.register_function(self.require, 'require')
        self.rpc_server.register_function(self.stop, 'stop')
        self.rpc_server.register_function(self.add_node, 'add_node')
        self.rpc_server.register_function(self.remove_node, 'remove_node')

        # Shut down cleanly on Ctrl-C / termination.
        signal.signal(signal.SIGINT, self.signal_handler)
        signal.signal(signal.SIGTERM, self.signal_handler)

    def init_logger_server(self, logger):
        """Serve a socket receiver feeding node records into *logger*."""
        self.log_server = LogRecordSocketReceiver(host=get_ip(), logger=logger)
        threading.Thread(target=self.log_server.serve_forever).start()

    def stop_logger_server(self):
        if hasattr(self, 'log_server'):
            self.log_server.shutdown()

    def client_stop(self):
        """Detach the client log handler (RPC: 'client_stop')."""
        if self.client_handler is not None:
            self.logger.removeHandler(self.client_handler)

    def check(self):
        """Raise JobMasterRunning unless this master may run."""
        env_legal = self.check_env(force=self.force)
        if not env_legal:
            raise JobMasterRunning('There has been a running job master.')

    def release_lock(self, lock):
        """Release *lock*, ignoring the error raised if already unlocked."""
        try:
            lock.release()
        except Exception:
            pass

    def finish(self):
        """Tear everything down: locks, loaders, log server, client."""
        self.release_lock(self.ready_lock)
        self.release_lock(self.finish_lock)

        LimitionJobLoader.finish(self)
        JobLoader.finish(self)
        self.stop_logger_server()

        # Best effort: a handler failing to close must not block teardown.
        try:
            for handler in self.logger.handlers:
                handler.close()
        except Exception:
            pass

        if self.client is not None:
            rpc_client = '%s:%s' % (self.client.split(':')[0],
                                    main_conf.client.port)
            client_call(rpc_client, 'stop', ignore=True)

        self.stopped = True

    def stop(self):
        """Ask every node to stop, then finish locally (RPC: 'stop')."""
        for node in self.nodes:
            client_call(node, 'stop', ignore=True)
        self.finish()

    def signal_handler(self, signum, frame):
        self.stop()

    def get_nodes(self):
        return self.nodes

    def ready(self, node):
        """RPC callback: *node* registered; unblock run() when all have."""
        if node in self.not_registered:
            self.not_registered.remove(node)
            if len(self.not_registered) == 0:
                self.ready_lock.release()

    def worker_finish(self, node):
        """RPC callback: *node* finished; unblock run() when all have."""
        if node in self.not_finished:
            self.not_finished.remove(node)
            if len(self.not_finished) == 0:
                self.finish_lock.release()

    def add_node(self, node):
        """Announce *node* to the existing nodes, then add and start it.

        Fix: the loop previously reused ``node`` as its loop variable,
        clobbering the argument -- peers were told about themselves and
        the wrong node was appended and started.
        """
        for peer in self.nodes:
            client_call(peer, 'add_node', node, ignore=True)
        self.nodes.append(node)
        client_call(node, 'run', ignore=True)

    def remove_node(self, node):
        """Announce removal of *node* to the others, then drop it.

        Fix: same loop-variable shadowing as in add_node.
        """
        for peer in self.nodes:
            client_call(peer, 'remove_node', node, ignore=True)
        if node in self.nodes:
            self.nodes.remove(node)

    def run(self):
        """Block until all nodes register, start them, block until done."""
        self.ready_lock.acquire()

        if not self.stopped and len(self.not_registered) == 0:
            self.mq_client.put(self.job.starts)
            for node in self.nodes:
                client_call(node, 'run')

        self.finish_lock.acquire()

        # Report completion to the master watcher (best effort).
        master_watcher = '%s:%s' % (get_ip(), main_conf.master.port)
        client_call(master_watcher,
                    'finish_job',
                    self.job.real_name,
                    ignore=True)

    def __enter__(self):
        return self

    def __exit__(self, type_, value, traceback):
        self.finish()
Example #13
0
 def init_logger_server(self, logger):
     """Start a socket log receiver for *logger* on this host's IP,
     serving on a background thread."""
     self.log_server = LogRecordSocketReceiver(host=get_ip(), logger=logger)
     # serve_forever blocks, hence the dedicated thread.
     threading.Thread(target=self.log_server.serve_forever).start()