class Cluster(object):
    def __init__(self,
                 command,
                 host='127.0.0.1',
                 sources=1,
                 workers=1,
                 sinks=1,
                 sink_mode='framed',
                 worker_join_timeout=30,
                 is_ready_timeout=30,
                 res_dir=None,
                 persistent_data=None):
        # Create attributes
        self._finalized = False
        self._exited = False
        self._raised = False
        self.command = command
        self.host = host
        self.workers = TypedList(types=(Runner, ))
        self.dead_workers = TypedList(types=(Runner, ))
        self.restarted_workers = TypedList(types=(Runner, ))
        self.runners = TypedList(types=(Runner, ))
        self.source_addrs = []
        self.sink_addrs = []
        self.sinks = []
        self.senders = []
        self.worker_join_timeout = worker_join_timeout
        self.is_ready_timeout = is_ready_timeout
        self.metrics = Metrics(host, mode='framed')
        self.errors = []
        self._worker_id_counter = 0
        if res_dir is None:
            self.res_dir = tempfile.mkdtemp(dir='/tmp/', prefix='res-data.')
        else:
            self.res_dir = res_dir
        # Avoid a shared mutable default argument; give each cluster its own dict
        self.persistent_data = (persistent_data
                                if persistent_data is not None else {})
        # Run a continuous crash checker in a background thread
        self._stoppables = set()
        self.crash_checker = CrashChecker(self)
        self.crash_checker.start()

        # Try to start everything... clean up on exception
        try:
            setup_resilience_path(self.res_dir)

            self.metrics.start()
            self.metrics_addr = ":".join(
                map(str, self.metrics.get_connection_info()))

            for _ in range(sinks):
                self.sinks.append(Sink(host, mode=sink_mode))
                self.sinks[-1].start()
                if self.sinks[-1].err is not None:
                    raise self.sinks[-1].err

            self.sink_addrs = [
                "{}:{}".format(*map(str, s.get_connection_info()))
                for s in self.sinks
            ]

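            # One port per source, plus three per worker: each worker needs
            # control, data, and external addresses.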
            num_ports = sources + 3 * workers
            ports = get_port_values(num=num_ports, host=host)
            addresses = ['{}:{}'.format(host, p) for p in ports]
            (self.source_addrs, worker_addrs) = (addresses[:sources], [
                addresses[sources:][i:i + 3]
                for i in xrange(0, len(addresses[sources:]), 3)
            ])
            start_runners(self.workers, self.command, self.source_addrs,
                          self.sink_addrs, self.metrics_addr, self.res_dir,
                          workers, worker_addrs)
            self.runners.extend(self.workers)
            self._worker_id_counter = len(self.workers)

            # Wait for all runners to report ready to process
            self.wait_to_resume_processing(self.is_ready_timeout)
            # make sure `workers` runners are active and listed in the
            # cluster status query
            logging.log(1, "Testing cluster size via obs query")
            self.query_observability(cluster_status_query,
                                     self.runners[0].external,
                                     tests=[(worker_count_matches, [workers])])
        except Exception as err:
            logging.error("Encountered an error when starting up the cluster")
            logging.exception(err)
            self.errors.append(err)
            self.__finally__()
            raise err

    #############
    # Autoscale #
    #############
    def grow(self, by=1, timeout=30, with_test=True):
        logging.log(
            1, "grow(by={}, timeout={}, with_test={})".format(
                by, timeout, with_test))
        pre_partitions = self.get_partition_data() if with_test else None
        runners = []
        new_ports = get_port_values(num=3 * by,
                                    host=self.host,
                                    base_port=25000)
        addrs = [[
            "{}:{}".format(self.host, p) for p in new_ports[i * 3:i * 3 + 3]
        ] for i in range(by)]
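        # Each new worker gets its own (control, data, external) address
        # triple from the newly allocated ports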
        for x in range(by):
            runner = add_runner(worker_id=self._worker_id_counter,
                                runners=self.workers,
                                command=self.command,
                                source_addrs=self.source_addrs,
                                sink_addrs=self.sink_addrs,
                                metrics_addr=self.metrics_addr,
                                control_addr=self.workers[0].control,
                                res_dir=self.res_dir,
                                workers=by,
                                my_control_addr=addrs[x][0],
                                my_data_addr=addrs[x][1],
                                my_external_addr=addrs[x][2])
            self._worker_id_counter += 1
            runners.append(runner)
            self.runners.append(runner)
        if with_test:
            workers = {'joining': [w.name for w in runners], 'leaving': []}
            self.confirm_migration(pre_partitions, workers, timeout=timeout)
        return runners

    def shrink(self, workers=1, timeout=30, with_test=True):
        logging.log(
            1, "shrink(workers={}, with_test={})".format(workers, with_test))
        # pick a worker that's not being shrunk
        if isinstance(workers, basestring):
            snames = set(workers.split(","))
            wnames = set([w.name for w in self.workers])
            complement = wnames - snames  # all members of wnames not in snames
            if not complement:
                raise ValueError("Can't shrink all workers!")
            for w in self.workers:
                if w.name in complement:
                    address = w.external
                    break
            leaving = list(filter(lambda w: w.name in snames, self.workers))
        elif isinstance(workers, int):
            if len(self.workers) <= workers:
                raise ValueError("Can't shrink all workers!")
            # choose last workers in `self.workers`
            leaving = self.workers[-workers:]
            address = self.workers[0].external
        else:
            raise ValueError("shrink(workers): workers must be an int or a"
                             " comma delimited string of worker names.")
        names = ",".join((w.name for w in leaving))
        pre_partitions = self.get_partition_data() if with_test else None
        # send shrink command to non-shrinking worker
        logging.log(1, "Sending a shrink command for ({})".format(names))
        resp = send_shrink_command(address, names)
        logging.log(1, "Response was: {}".format(resp))
        # no error, so command was successful, update self.workers
        for w in leaving:
            self.workers.remove(w)
        self.dead_workers.extend(leaving)
        if with_test:
            workers = {'leaving': [w.name for w in leaving], 'joining': []}
            self.confirm_migration(pre_partitions, workers, timeout=timeout)
        return leaving

    def get_partition_data(self):
        logging.log(1, "get_partition_data()")
        addresses = [(w.name, w.external) for w in self.workers]
        responses = multi_states_query(addresses)
        return coalesce_partition_query_responses(responses)

    def confirm_migration(self, pre_partitions, workers, timeout=120):
        logging.log(
            1, "confirm_migration(pre_partitions={}, workers={},"
            " timeout={})".format(pre_partitions, workers, timeout))

        def pre_process():
            addresses = [(r.name, r.external) for r in self.workers]
            responses = multi_states_query(addresses)
            post_partitions = coalesce_partition_query_responses(responses)
            return (pre_partitions, post_partitions, workers)

        # retry the test until it passes or a timeout elapses
        logging.log(1, "Running pre_process func with try_until")
        tut = TryUntilTimeout(validate_migration,
                              pre_process,
                              timeout=timeout,
                              interval=2)
        self._stoppables.add(tut)
        tut.start()
        tut.join()
        self._stoppables.discard(tut)
        if tut.error:
            logging.error(
                "validate_partitions failed with inputs:"
                "(pre_partitions: {!r}, post_partitions: {!r}, workers: {!r})".
                format(*(tut.args if tut.args is not None else (None, None,
                                                                None))))
            raise tut.error

    #####################
    # Worker management #
    #####################
    def kill_worker(self, worker=-1):
        """
        Kill a worker, move it from `workers` to `dead_workers`, and return
        it.
        If `worker` is an int, perform this on the Runner instance at `worker`
        position in the `workers` list.
        If `worker` is a Runner instance, perform this on that instance.
        """
        logging.log(1, "kill_worker(worker={})".format(worker))
        if isinstance(worker, Runner):
            # ref to worker
            self.workers.remove(worker)
            r = worker
        else:  # index of worker in self.workers
            r = self.workers.pop(worker)
        r.kill()
        self.dead_workers.append(r)
        return r

    def stop_worker(self, worker=-1):
        """
        Stop a worker, move it from `workers` to `dead_workers`, and return
        it.
        If `worker` is an int, perform this on the Runner instance at `worker`
        position in the `workers` list.
        If `worker` is a Runner instance, perform this on that instance.
        """
        logging.log(1, "stop_worker(worker={})".format(worker))
        if isinstance(worker, Runner):
            # ref to worker; list.remove returns None, so keep our own ref
            self.workers.remove(worker)
            r = worker
        else:  # index of worker in self.workers
            r = self.workers.pop(worker)
        r.stop()
        self.dead_workers.append(r)
        return r

    def restart_worker(self, worker=-1):
        """
        Restart one or more workers via the `respawn` method of a Runner,
        then add the new Runner instances to `workers`.
        If `worker` is an int, perform this on the Runner instance at `worker`
        position in the `dead_workers` list.
        If `worker` is a Runner instance, perform this on that instance.
        If `worker` is a list of Runners, perform this on each.
        If `worker` is a slice, perform this on the slice of self.dead_workers
        """
        logging.log(1, "restart_worker(worker={})".format(worker))
        if isinstance(worker, Runner):
            # ref to dead worker instance
            old_rs = [worker]
        elif isinstance(worker, (list, tuple)):
            old_rs = worker
        elif isinstance(worker, slice):
            old_rs = self.dead_workers[worker]
        else:  # index of worker in self.dead_workers
            old_rs = [self.dead_workers[worker]]
        new_rs = []
        for r in old_rs:
            new_rs.append(r.respawn())
        for r in new_rs:
            r.start()
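        # Brief pause to let the respawned processes come up before the
        # liveness checks below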
        time.sleep(0.05)

        # Wait until all worker processes have started
        new_rs = tuple(new_rs)

        def check_alive():
            # Raise if any runner is not alive yet so that TryUntilTimeout
            # keeps retrying until every new runner reports alive
            assert all(r.is_alive() for r in new_rs)

        tut = TryUntilTimeout(check_alive, timeout=5, interval=0.1)
        self._stoppables.add(tut)
        tut.start()
        tut.join()
        self._stoppables.discard(tut)
        if tut.error:
            logging.error("Bad starters: {}".format(new_rs))
            raise tut.error
        logging.debug("All new runners started successfully")

        for r in new_rs:
            self.restarted_workers.append(r)
            self.workers.append(r)
            self.runners.append(r)
        return new_rs

    def stop_workers(self):
        logging.log(1, "stop_workers()")
        for r in self.runners:
            r.stop()
        # move all live workers to dead_workers
        self.dead_workers.extend(self.workers)
        self.workers = []

    def get_crashed_workers(self,
                            func=lambda r: r.poll() not in (None, 0, -9, -15)):
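        # The default filter treats any return code other than None (still
        # running), 0 (clean exit), -9 (SIGKILL), or -15 (SIGTERM) as a crash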
        logging.log(1, "get_crashed_workers()")
        return list(filter(func, self.runners))

    #########
    # Sinks #
    #########
    def stop_sinks(self):
        logging.log(1, "stop_sinks()")
        for s in self.sinks:
            s.stop()

    def sink_await(self, values, timeout=30, func=lambda x: x, sink=-1):
        logging.log(
            1, "sink_await(values={}, timeout={}, func: {}, sink={})".format(
                values, timeout, func, sink))
        if isinstance(sink, Sink):
            pass
        else:
            sink = self.sinks[sink]
        t = SinkAwaitValue(sink, values, timeout, func)
        self._stoppables.add(t)
        t.start()
        t.join()
        self._stoppables.discard(t)
        if t.error:
            raise t.error
        return sink

    def sink_expect(self, expected, timeout=30, sink=-1, allow_more=False):
        logging.log(
            1, "sink_expect(expected={}, timeout={}, sink={})".format(
                expected, timeout, sink))
        if isinstance(sink, Sink):
            pass
        else:
            sink = self.sinks[sink]
        t = SinkExpect(sink, expected, timeout, allow_more=allow_more)
        self._stoppables.add(t)
        t.start()
        t.join()
        self._stoppables.discard(t)
        if t.error:
            raise t.error
        return sink

    ###########
    # Senders #
    ###########
    def add_sender(self, sender, start=False):
        logging.log(1, "add_sender(sender={}, start={})".format(sender, start))
        self.senders.append(sender)
        if start:
            sender.start()

    def wait_for_sender(self, sender=-1, timeout=30):
        logging.log(
            1,
            "wait_for_sender(sender={}, timeout={})".format(sender, timeout))
        if isinstance(sender, Sender):
            pass
        else:
            sender = self.senders[sender]
        self._stoppables.add(sender)
        sender.join(timeout)
        self._stoppables.discard(sender)
        if sender.error:
            raise sender.error
        if sender.is_alive():
            raise TimeoutError('Sender did not complete in the expected '
                               'period')

    def stop_senders(self):
        logging.log(1, "stop_senders()")
        for s in self.senders:
            s.stop()

    def pause_senders(self):
        logging.log(1, "pause_senders()")
        for s in self.senders:
            s.pause()
        self.wait_for_senders_to_flush()

    def wait_for_senders_to_flush(self, timeout=30):
        logging.log(1, "wait_for_senders_to_flush({})".format(timeout))
        awaiters = []
        for s in self.senders:
            a = TryUntilTimeout(validate_sender_is_flushed,
                                pre_process=(s, ),
                                timeout=timeout,
                                interval=0.1)
            self._stoppables.add(a)
            awaiters.append(a)
            a.start()
        for a in awaiters:
            a.join()
            self._stoppables.discard(a)
            if a.error:
                raise a.error
            else:
                logging.debug("Sender is fully flushed after pausing.")

    def resume_senders(self):
        logging.log(1, "resume_senders()")
        for s in self.senders:
            s.resume()

    ###########
    # Cluster #
    ###########
    def wait_to_resume_processing(self, timeout=30):
        logging.log(1, "wait_to_resume_processing(timeout={})".format(timeout))
        w = WaitForClusterToResumeProcessing(self.workers, timeout=timeout)
        self._stoppables.add(w)
        w.start()
        w.join()
        self._stoppables.discard(w)
        if w.error:
            raise w.error

    def stop_cluster(self):
        logging.log(1, "stop_cluster()")
        self.stop_senders()
        self.stop_workers()
        self.stop_sinks()

    #########################
    # Observability queries #
    #########################
    def query_observability(self, query, args, tests, timeout=30, period=2):
        logging.log(
            1, "query_observability(query={}, args={}, tests={}, "
            "timeout={}, period={})".format(query, args, tests, timeout,
                                            period))
        obs = ObservabilityNotifier(query, args, tests, timeout, period)
        self._stoppables.add(obs)
        obs.start()
        obs.join()
        self._stoppables.discard(obs)
        if obs.error:
            raise obs.error

    ####################
    # Context managers #
    ####################
    def __enter__(self):
        return self

    def stop_background_threads(self, error=None):
        logging.log(1, "stop_background_threads({})".format(error))
        for s in self._stoppables:
            s.stop(error)
        self._stoppables.clear()

    def raise_from_error(self, error):
        logging.log(1, "raise_from_error({})".format(error))
        if self._raised:
            return
        self.stop_background_threads(error)
        self.__exit__(type(error), error, None)
        self._raised = True

    def __exit__(self, _type, _value, _traceback):
        logging.log(1, "__exit__({}, {}, {})".format(_type, _value,
                                                     _traceback))
        if self._exited:
            return
        # clean up any remaining runner processes
        if _type or _value or _traceback:
            logging.error('An error was raised in the Cluster context',
                          exc_info=(_type, _value, _traceback))
            self.stop_background_threads(_value)
        else:
            self.stop_background_threads()
        try:
            for w in self.workers:
                w.stop()
            for w in self.dead_workers:
                w.stop()
            # Wait on runners to finish waiting on their subprocesses to exit
            for w in self.runners:
                # Check thread ident to avoid error when joining an un-started
                # thread.
                if w.ident:  # ident is set when a thread is started
                    w.join(self.worker_join_timeout)
            alive = []
            while self.workers:
                w = self.workers.pop()
                if w.is_alive():
                    alive.append(w)
                else:
                    self.dead_workers.append(w)
            if alive:
                alive_names = ', '.join((w.name for w in alive))
                raise ClusterError("Runners [{}] failed to exit cleanly after"
                                   " {} seconds.".format(
                                       alive_names, self.worker_join_timeout))
            # check for workers that exited with a non-zero return code
            # note that workers killed in the previous step have code -15
            bad_exit = []
            for w in self.dead_workers:
                c = w.returncode()
                if c not in (0, -9, -15):  # -9: SIGKILL, -15: SIGTERM
                    bad_exit.append(w)
            if bad_exit:
                raise ClusterError("The following workers terminated with "
                                   "a bad exit code: {}".format([
                                       "(name: {}, pid: {}, code:{})".format(
                                           w.name, w.pid, w.returncode())
                                       for w in bad_exit
                                   ]))
        finally:
            self.__finally__()
            if self.errors:
                logging.error("Errors were encountered when running"
                              " the cluster")
                for e in self.errors:
                    logging.exception(e)
            self._exited = True

    def __finally__(self):
        logging.log(1, "__finally__()")
        self.stop_background_threads()
        if self._finalized:
            return
        logging.info("Doing final cleanup")
        for w in self.runners:
            w.kill()
        for s in self.sinks:
            s.stop()
        for s in self.senders:
            s.stop()
        self.metrics.stop()
        self.persistent_data['runner_data'] = [
            RunnerData(r.name, r.command, r.pid, r.returncode(),
                       r.get_output(), r.start_time) for r in self.runners
        ]
        self.persistent_data['sender_data'] = [
            SenderData(s.name, s.address, s.start_time, s.data)
            for s in self.senders
        ]
        self.persistent_data['sink_data'] = [
            SinkData(s.name, s.address, s.start_time, s.data)
            for s in self.sinks
        ]
        clean_resilience_path(self.res_dir)
        self._finalized = True
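

# A minimal usage sketch, not part of the original source. The worker
# command string and the grow/shrink sequence below are illustrative
# assumptions; cleanup is handled by the context manager on exit.
if __name__ == '__main__':
    command = './my_worker_binary --run-tests'  # hypothetical worker command
    with Cluster(command=command, workers=2, sources=1, sinks=1) as cluster:
        # Add one worker and confirm partition migration via the obs query
        cluster.grow(by=1)
        # Remove one worker again
        cluster.shrink(workers=1)
        # Wait for any attached senders to drain before exiting the context
        for sender in list(cluster.senders):
            cluster.wait_for_sender(sender)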