Example #1
    def __init__(self, global_config, channel_config_loader, server_address):
        xmlrpc.server.SimpleXMLRPCServer.__init__(
            self,
            server_address,
            logRequests=False,
            requestHandler=RequestHandler)
        signal.signal(signal.SIGHUP, self.handle_sighup)
        self.channel_config_loader = channel_config_loader
        self.global_config = global_config
        self.dataspace = dataspace.DataSpace(self.global_config)
        self.reaper = Reaper(self.global_config)
        self.startup_complete = Event()
        self.logger = structlog.getLogger(LOGGERNAME)
        self.logger = self.logger.bind(module=__name__.split(".")[-1],
                                       channel=DELOGGER_CHANNEL_NAME)
        self.logger.debug(f"DecisionEngine starting on {server_address}")

        exchange_name = self.global_config.get("exchange_name",
                                               "hepcloud_topic_exchange")
        self.logger.debug(f"Creating topic exchange {exchange_name}")
        self.exchange = Exchange(exchange_name, "topic")
        self.broker_url = self.global_config.get("broker_url",
                                                 "redis://localhost:6379/0")
        _verify_redis_server(self.broker_url)

        self.source_workers = SourceWorkers(self.exchange, self.broker_url,
                                            self.logger)
        self.channel_workers = ChannelWorkers()

        self.register_function(self.rpc_metrics, name="metrics")

        self.logger.info(
            f"DecisionEngine __init__ complete {server_address} with {self.broker_url}"
        )
Example #2
def main():  # pragma: no cover
    username = pwd.getpwuid(os.getuid()).pw_name
    if username not in ['root', 'decisionengine']:
        sys.exit(f"User '{username}' is not allowed to run this script.")

    config_file = policies.global_config_file()
    global_config = ValidConfig(config_file)
    reaper = Reaper(global_config)
    reaper.reap()
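Here reap() performs one synchronous pass. Based on the Reaper API exercised in the tests below (start(delay=...), stop(), state.get()), a background run might look like the following sketch; the 30-second delay is an arbitrary assumption:

reaper = Reaper(global_config)
reaper.start(delay=30)     # spawn the reaper thread after a 30-second delay
print(reaper.state.get())  # e.g. State.IDLE while the delay timer is running
reaper.stop()              # request shutdown; state moves toward SHUTDOWN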
Example #3
def reaper(request):
    config_fixture = request.getfixturevalue("config")
    reaper = Reaper(config_fixture)

    yield reaper

    with contextlib.suppress(Exception):
        if reaper.thread.is_alive() or not reaper.state.should_stop():
            reaper.state.set(State.OFFLINE)
            reaper.join(timeout=1)

    del reaper
    gc.collect()
Example #4
def reaper(request):
    config_fixture = request.getfixturevalue("config")
    reaper = Reaper(config_fixture)

    yield reaper

    try:
        if reaper.thread.is_alive() or not reaper.state.should_stop():
            reaper.state.set(State.OFFLINE)
            reaper.join(timeout=1)
    except Exception:
        pass

    del reaper
    gc.collect()
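Both fixture variants look up a config fixture by name with request.getfixturevalue; its definition is not part of these examples. A plausible minimal sketch, assuming a module-level GLOBAL_CONFIG dict as in Example #16, deep-copied so each test can mutate or delete nested "dataspace" keys without affecting other tests:

import copy

import pytest

@pytest.fixture
def config():
    # Hypothetical fixture: hand each test its own deep copy of GLOBAL_CONFIG
    # so changes to config["dataspace"] do not leak between tests.
    return copy.deepcopy(GLOBAL_CONFIG)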
Example #5
    def __init__(self, global_config, channel_config_loader, server_address):
        xmlrpc.server.SimpleXMLRPCServer.__init__(
            self,
            server_address,
            logRequests=False,
            requestHandler=RequestHandler)

        self.logger = logging.getLogger("decision_engine")
        signal.signal(signal.SIGHUP, self.handle_sighup)
        self.workers = Workers()
        self.channel_config_loader = channel_config_loader
        self.global_config = global_config
        self.dataspace = dataspace.DataSpace(self.global_config)
        self.reaper = Reaper(self.global_config)
        self.logger.info("DecisionEngine started on {}".format(server_address))
Example #6
def test_fail_bad_config(self):
    with self.assertRaises(dataspace.DataSpaceConfigurationError):
        with mock.patch.object(dataspace.DataSourceLoader,
                               "create_datasource") as source:
            source.return_value = MockSource()
            test_config = GLOBAL_CONFIG.copy()
            test_config["dataspace"] = 'somestring'
            Reaper(test_config)
Example #7
def test_fail_wrong_config_key(self):
    with self.assertRaises(ValueError):
        with mock.patch.object(dataspace.DataSourceLoader,
                               "create_datasource") as source:
            source.return_value = MockSource()
            test_config = GLOBAL_CONFIG.copy()
            test_config["dataspace"]["retention_interval_in_days"] = "abc"
            Reaper(test_config)
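Note that GLOBAL_CONFIG.copy() is a shallow copy: the nested assignment in Example #7 (and the same pattern in Example #11) writes "abc" into the very "dataspace" dict that GLOBAL_CONFIG holds, leaking into later tests; the top-level rebind in Example #6 is harmless. A deep copy isolates both cases; a sketch of Example #7 with that change, otherwise relying on the same imports and names:

import copy

def test_fail_wrong_config_key(self):
    with self.assertRaises(ValueError):
        with mock.patch.object(dataspace.DataSourceLoader,
                               "create_datasource") as source:
            source.return_value = MockSource()
            test_config = copy.deepcopy(GLOBAL_CONFIG)  # own "dataspace" dict
            test_config["dataspace"]["retention_interval_in_days"] = "abc"
            Reaper(test_config)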
Example #8
    def __init__(self, global_config, channel_config_loader, server_address):
        xmlrpc.server.SimpleXMLRPCServer.__init__(
            self,
            server_address,
            logRequests=False,
            requestHandler=RequestHandler)

        signal.signal(signal.SIGHUP, self.handle_sighup)
        self.workers = Workers()
        self.channel_config_loader = channel_config_loader
        self.global_config = global_config
        self.dataspace = dataspace.DataSpace(self.global_config)
        self.reaper = Reaper(self.global_config)
        self.startup_complete = Event()
        self.logger = structlog.getLogger(LOGGERNAME)
        self.logger = self.logger.bind(module=__name__.split(".")[-1],
                                       channel=DELOGGER_CHANNEL_NAME)
        self.logger.info(f"DecisionEngine started on {server_address}")
Example #9
    def __init__(self, global_config, channel_config_loader, server_address):
        xmlrpc.server.SimpleXMLRPCServer.__init__(
            self, server_address, logRequests=False, requestHandler=RequestHandler
        )
        signal.signal(signal.SIGHUP, self.handle_sighup)
        self.source_workers = {}
        self.channel_workers = Workers()
        self.channel_config_loader = channel_config_loader
        self.global_config = global_config
        self.dataspace = dataspace.DataSpace(self.global_config)
        self.reaper = Reaper(self.global_config)
        self.startup_complete = Event()
        self.logger = structlog.getLogger(LOGGERNAME)
        self.logger = self.logger.bind(module=__name__.split(".")[-1], channel=DELOGGER_CHANNEL_NAME)
        self.logger.info(f"DecisionEngine started on {server_address}")
        self.register_function(self.rpc_metrics, name="metrics")
        if not global_config.get("no_webserver"):
            self.start_webserver()

        self.broker_url = self.global_config.get("broker_url", "redis://localhost:6379/0")
        _verify_redis_server(self.broker_url)
Example #10
class DecisionEngine(socketserver.ThreadingMixIn,
                     xmlrpc.server.SimpleXMLRPCServer):
    def __init__(self, global_config, channel_config_loader, server_address):
        xmlrpc.server.SimpleXMLRPCServer.__init__(
            self,
            server_address,
            logRequests=False,
            requestHandler=RequestHandler)

        self.logger = logging.getLogger("decision_engine")
        signal.signal(signal.SIGHUP, self.handle_sighup)
        self.workers = Workers()
        self.channel_config_loader = channel_config_loader
        self.global_config = global_config
        self.dataspace = dataspace.DataSpace(self.global_config)
        self.reaper = Reaper(self.global_config)
        self.startup_complete = Event()
        self.logger.info("DecisionEngine started on {}".format(server_address))

    def get_logger(self):
        return self.logger

    def _dispatch(self, method, params):
        try:
            # methods allowed to be executed by rpc have 'rpc_' pre-pended
            func = getattr(self, "rpc_" + method)
        except AttributeError:
            raise Exception(f'method "{method}" is not supported')
        return func(*params)

    def block_until(self, state, timeout=None):
        with self.workers.unguarded_access() as workers:
            if not workers:
                self.logger.info('No active channels to wait on.')
                return 'No active channels.'
            for tm in workers.values():
                if tm.is_alive():
                    tm.wait_until(state, timeout)
        return f'No channels in {state} state.'

    def block_while(self, state, timeout=None):
        with self.workers.unguarded_access() as workers:
            if not workers:
                self.logger.info('No active channels to wait on.')
                return 'No active channels.'
            for tm in workers.values():
                if tm.is_alive():
                    tm.wait_while(state, timeout)
        return f'No channels in {state} state.'

    def _dataframe_to_table(self, df):
        return "{}\n".format(
            tabulate.tabulate(df, headers='keys', tablefmt='psql'))

    def _dataframe_to_vertical_tables(self, df):
        txt = ""
        for i in range(len(df)):
            txt += f"Row {i}\n"
            txt += "{}\n".format(
                tabulate.tabulate(df.T.iloc[:, [i]], tablefmt='psql'))
        return txt

    def _dataframe_to_column_names(self, df):
        columns = df.columns.values.reshape([len(df.columns), 1])
        return "{}\n".format(
            tabulate.tabulate(columns, headers=['columns'], tablefmt='psql'))

    def _dataframe_to_json(self, df):
        return "{}\n".format(json.dumps(json.loads(df.to_json()), indent=4))

    def _dataframe_to_csv(self, df):
        return "{}\n".format(df.to_csv())

    def rpc_block_while(self, state_str, timeout=None):
        allowed_state = None
        try:
            allowed_state = ProcessingState.State[state_str]
        except Exception:
            return f'{state_str} is not a valid channel state.'
        return self.block_while(allowed_state, timeout)

    def rpc_show_config(self, channel):
        """
        Show the configuration for a channel.

        :type channel: string
        """
        txt = ""
        channels = self.channel_config_loader.get_channels()
        if channel == 'all':
            for ch in channels:
                txt += _channel_preamble(ch)
                txt += self.channel_config_loader.print_channel_config(ch)
            return txt

        if channel not in channels:
            return f"There is no active channel named {channel}."

        txt += _channel_preamble(channel)
        txt += self.channel_config_loader.print_channel_config(channel)
        return txt

    def rpc_show_de_config(self):
        return self.global_config.dump()

    def rpc_print_product(self,
                          product,
                          columns=None,
                          query=None,
                          types=False,
                          format=None):
        found = False
        txt = "Product {}: ".format(product)
        with self.workers.access() as workers:
            for ch, worker in workers.items():
                if not worker.is_alive():
                    txt += f"Channel {ch} is in not active\n"
                    continue

                produces = worker.get_produces()
                r = [x for x in list(produces.items()) if product in x[1]]
                if not r:
                    continue
                found = True
                txt += " Found in channel {}\n".format(ch)
                tm = self.dataspace.get_taskmanager(ch)
                try:
                    data_block = datablock.DataBlock(
                        self.dataspace,
                        ch,
                        taskmanager_id=tm['taskmanager_id'],
                        sequence_id=tm['sequence_id'])
                    data_block.generation_id -= 1
                    df = data_block[product]
                    df = pd.read_json(df.to_json())
                    dataframe_formatter = self._dataframe_to_table
                    if format == 'vertical':
                        dataframe_formatter = self._dataframe_to_vertical_tables
                    if format == 'column-names':
                        dataframe_formatter = self._dataframe_to_column_names
                    if format == 'json':
                        dataframe_formatter = self._dataframe_to_json
                    if types:
                        for column in df.columns:
                            df.insert(
                                df.columns.get_loc(column) + 1,
                                f"{column}.type", df[column].transform(
                                    lambda x: type(x).__name__))
                    column_names = []
                    if columns:
                        column_names = columns.split(",")
                    if query:
                        if column_names:
                            txt += dataframe_formatter(
                                df.loc[:, column_names].query(query))
                        else:
                            txt += dataframe_formatter(df.query(query))

                    else:
                        if column_names:
                            txt += dataframe_formatter(df.loc[:, column_names])
                        else:
                            txt += dataframe_formatter(df)
                except Exception as e:  # pragma: no cover
                    txt += "\t\t{}\n".format(e)
        if not found:
            txt += "Not produced by any module\n"
        return txt[:-1]

    def rpc_print_products(self):
        with self.workers.access() as workers:
            channel_keys = workers.keys()
            if not channel_keys:
                return "No channels are currently active.\n"

            width = max([len(x) for x in channel_keys]) + 1
            txt = ""
            for ch, worker in workers.items():
                if not worker.is_alive():
                    txt += f"Channel {ch} is in ERROR state\n"
                    continue

                txt += "channel: {:<{width}}, id = {:<{width}}, state = {:<10} \n".format(
                    ch,
                    worker.task_manager_id,
                    worker.get_state_name(),
                    width=width)
                tm = self.dataspace.get_taskmanager(ch)
                data_block = datablock.DataBlock(
                    self.dataspace,
                    ch,
                    taskmanager_id=tm['taskmanager_id'],
                    sequence_id=tm['sequence_id'])
                data_block.generation_id -= 1
                channel_config = self.channel_config_loader.get_channels()[ch]
                produces = worker.get_produces()
                for i in ("sources", "transforms", "logicengines",
                          "publishers"):
                    txt += "\t{}:\n".format(i)
                    modules = channel_config.get(i, {})
                    for mod_name, mod_config in modules.items():
                        txt += "\t\t{}\n".format(mod_name)
                        products = produces.get(mod_name, [])
                        for product in products:
                            try:
                                df = data_block[product]
                                df = pd.read_json(df.to_json())
                                txt += "{}\n".format(
                                    tabulate.tabulate(df,
                                                      headers='keys',
                                                      tablefmt='psql'))
                            except Exception as e:  # pragma: no cover
                                txt += "\t\t\t{}\n".format(e)
        return txt[:-1]

    def rpc_status(self):
        with self.workers.access() as workers:
            channel_keys = workers.keys()
            if not channel_keys:
                return "No channels are currently active.\n" + self.reaper_status(
                )

            txt = ""
            width = max([len(x) for x in channel_keys]) + 1
            for ch, worker in workers.items():
                txt += "channel: {:<{width}}, id = {:<{width}}, state = {:<10} \n".format(
                    ch,
                    worker.task_manager_id,
                    worker.get_state_name(),
                    width=width)
                produces = worker.get_produces()
                consumes = worker.get_consumes()
                channel_config = self.channel_config_loader.get_channels()[ch]
                for i in ("sources", "transforms", "logicengines",
                          "publishers"):
                    txt += "\t{}:\n".format(i)
                    modules = channel_config.get(i, {})
                    for mod_name, mod_config in modules.items():
                        txt += "\t\t{}\n".format(mod_name)
                        txt += "\t\t\tconsumes : {}\n".format(
                            consumes.get(mod_name, []))
                        txt += "\t\t\tproduces : {}\n".format(
                            produces.get(mod_name, []))
        return txt + self.reaper_status()

    def rpc_stop(self):
        self.shutdown()
        self.stop_channels()
        self.reaper_stop()
        return "OK"

    def start_channel(self, channel_name, channel_config):
        generation_id = 1
        task_manager = TaskManager.TaskManager(channel_name, generation_id,
                                               channel_config,
                                               self.global_config)
        worker = Worker(task_manager, self.global_config['logger'])
        with self.workers.access() as workers:
            workers[channel_name] = worker
        self.logger.debug(f"Trying to start {channel_name}")
        worker.start()
        worker.wait_while(ProcessingState.State['BOOT'])
        self.logger.info(f"Channel {channel_name} started")

    def start_channels(self):
        self.channel_config_loader.load_all_channels()

        if not self.channel_config_loader.get_channels():
            self.logger.info(
                "No channel configurations available in " +
                f"{self.channel_config_loader.channel_config_dir}")
        else:
            self.logger.debug(
                f"Found channels: {self.channel_config_loader.get_channels().items()}"
            )

        for name, config in self.channel_config_loader.get_channels().items():
            try:
                self.start_channel(name, config)
            except Exception as e:
                self.logger.exception(f"Channel {name} failed to start : {e}")

    def rpc_start_channel(self, channel_name):
        with self.workers.access() as workers:
            if channel_name in workers:
                return f"ERROR, channel {channel_name} is running"

        success, result = self.channel_config_loader.load_channel(channel_name)
        if not success:
            return result
        self.start_channel(channel_name, result)
        return "OK"

    def rpc_start_channels(self):
        self.start_channels()
        return "OK"

    def rpc_stop_channel(self, channel):
        return self.rpc_rm_channel(channel, None)

    def rpc_kill_channel(self, channel, timeout=None):
        if timeout is None:
            timeout = self.global_config.get("shutdown_timeout", 10)
        return self.rpc_rm_channel(channel, timeout)

    def rpc_rm_channel(self, channel, maybe_timeout):
        rc = self.rm_channel(channel, maybe_timeout)
        if rc == StopState.NotFound:
            return f"No channel found with the name {channel}."
        elif rc == StopState.Terminated:
            if maybe_timeout == 0:
                return f"Channel {channel} has been killed."
            # Would be better to use something like the inflect
            # module, but that introduces another dependency.
            suffix = 's' if maybe_timeout > 1 else ''
            return f"Channel {channel} has been killed due to shutdown timeout ({maybe_timeout} second{suffix})."
        assert rc == StopState.Clean
        return f"Channel {channel} stopped cleanly."

    def rm_channel(self, channel, maybe_timeout):
        rc = None
        with self.workers.access() as workers:
            if channel not in workers:
                return StopState.NotFound
            self.logger.debug(f"Trying to stop {channel}")
            rc = self.stop_worker(workers[channel], maybe_timeout)
            del workers[channel]
        return rc

    def stop_worker(self, worker, timeout):
        if worker.is_alive():
            self.logger.debug("Trying to shutdown worker")
            worker.task_manager.set_to_shutdown()
            self.logger.debug("Trying to take worker offline")
            worker.task_manager.take_offline(None)
            worker.join(timeout)
        if worker.exitcode is None:
            worker.terminate()
            return StopState.Terminated
        else:
            return StopState.Clean

    def stop_channels(self):
        timeout = self.global_config.get("shutdown_timeout", 10)
        with self.workers.access() as workers:
            for worker in workers.values():
                self.stop_worker(worker, timeout)
            workers.clear()

    def rpc_stop_channels(self):
        self.stop_channels()
        return "All channels stopped."

    def handle_sighup(self, signum, frame):
        self.reaper_stop()
        self.stop_channels()
        self.start_channels()
        self.reaper_start(delay=self.global_config['dataspace'].get(
            'reaper_start_delay_seconds', 1818))

    def rpc_get_log_level(self):
        engineloglevel = self.get_logger().getEffectiveLevel()
        return logging.getLevelName(engineloglevel)

    def rpc_get_channel_log_level(self, channel):
        with self.workers.access() as workers:
            if channel not in workers:
                return f"No channel found with the name {channel}."

            worker = workers[channel]
            if not worker.is_alive():
                return f"Channel {channel} is in ERROR state."
            return logging.getLevelName(worker.task_manager.get_loglevel())

    def rpc_set_channel_log_level(self, channel, log_level):
        """Assumes log_level is a string corresponding to the supported logging-module levels."""
        with self.workers.access() as workers:
            if channel not in workers:
                return f"No channel found with the name {channel}."

            worker = workers[channel]
            if not worker.is_alive():
                return f"Channel {channel} is in ERROR state."

            log_level_code = getattr(logging, log_level)
            if worker.task_manager.get_loglevel() == log_level_code:
                return f"Nothing to do. Current log level is : {log_level}"
            worker.task_manager.set_loglevel_value(log_level)
        return f"Log level changed to : {log_level}"

    def rpc_reaper_start(self, delay=0):
        '''
        Start the reaper process after 'delay' seconds.
        Default 0 seconds delay.
        :type delay: int
        '''
        self.reaper_start(delay)
        return "OK"

    def reaper_start(self, delay):
        self.reaper.start(delay)

    def rpc_reaper_stop(self):
        self.reaper_stop()
        return "OK"

    def reaper_stop(self):
        self.reaper.stop()

    def rpc_reaper_status(self):
        interval = self.reaper.retention_interval
        state = self.reaper.state.get()
        txt = 'reaper:\n\tstate: {}\n\tretention_interval: {}'.format(
            state, interval)
        return txt

    def reaper_status(self):
        interval = self.reaper.retention_interval
        state = self.reaper.state.get()
        txt = '\nreaper:\n\tstate: {}\n\tretention_interval: {}\n'.format(
            state, interval)
        return txt

    def rpc_query_tool(self, product, format=None, start_time=None):
        found = False
        result = pd.DataFrame()
        txt = "Product {}: ".format(product)

        with self.workers.access() as workers:
            for ch, worker in workers.items():
                if not worker.is_alive():
                    txt += f"Channel {ch} is in not active\n"
                    continue

                produces = worker.get_produces()
                r = [x for x in list(produces.items()) if product in x[1]]
                if not r:
                    continue
                found = True
                txt += " Found in channel {}\n".format(ch)

                if start_time:
                    tms = self.dataspace.get_taskmanagers(
                        ch, start_time=start_time)
                else:
                    tms = [self.dataspace.get_taskmanager(ch)]
                for tm in tms:
                    try:
                        data_block = datablock.DataBlock(
                            self.dataspace,
                            ch,
                            taskmanager_id=tm['taskmanager_id'],
                            sequence_id=tm['sequence_id'])
                        products = data_block.get_dataproducts(product)
                        for p in products:
                            df = p["value"]
                            if df.shape[0] > 0:
                                df["channel"] = [tm["name"]] * df.shape[0]
                                df["taskmanager_id"] = [p["taskmanager_id"]
                                                        ] * df.shape[0]
                                df["generation_id"] = [p["generation_id"]
                                                       ] * df.shape[0]
                                result = result.append(df)
                    except Exception as e:  # pragma: no cover
                        txt += "\t\t{}\n".format(e)

        if found:
            dataframe_formatter = self._dataframe_to_table
            if format == "csv":
                dataframe_formatter = self._dataframe_to_csv
            if format == "json":
                dataframe_formatter = self._dataframe_to_json
            result = result.reset_index(drop=True)
            txt += dataframe_formatter(result)
        else:
            txt += "Not produced by any module\n"

        return txt
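rpc_block_while looks its argument up by name in ProcessingState.State, so callers pass a state name as a string. A hypothetical call on a running DecisionEngine instance (here named engine), using state names that appear throughout these examples:

# Valid names seen in these examples include BOOT, IDLE, ACTIVE, STEADY,
# SHUTTINGDOWN, and SHUTDOWN; engine is a hypothetical DecisionEngine instance.
print(engine.rpc_block_while("BOOT", timeout=10))
print(engine.rpc_block_while("NOT_A_STATE"))  # -> "NOT_A_STATE is not a valid channel state."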
Example #11
class TestReaper(unittest.TestCase):
    logger = logging.getLogger()

    def setUp(self):
        with mock.patch.object(dataspace.DataSourceLoader,
                               "create_datasource") as source:
            source.return_value = MockSource()
            GLOBAL_CONFIG["dataspace"]["retention_interval_in_days"] = 365
            self.reaper = Reaper(GLOBAL_CONFIG)

    def tearDown(self):
        # Make sure there are no dangling reapers
        try:
            if self.reaper.thread.is_alive() or not self.reaper.state.should_stop():
                self.reaper.state.set(State.OFFLINE)
                time.sleep(0.5)
        except Exception:
            pass

    def test_reap_default_state(self):
        self.assertEqual(self.reaper.state.get(), State.BOOT)

    def test_reaper_can_reap(self):
        self.reaper.reap()

    def test_just_stop_no_error(self):
        self.reaper.stop()

    def test_start_stop(self):
        self.reaper.start()
        self.assertIn(self.reaper.state.get(),
                      (State.IDLE, State.ACTIVE, State.STEADY))

        self.reaper.stop()
        self.assertIn(self.reaper.state.get(),
                      (State.SHUTTINGDOWN, State.SHUTDOWN))

    def test_start_stop_stop(self):
        self.reaper.start()
        self.assertIn(self.reaper.state.get(),
                      (State.IDLE, State.ACTIVE, State.STEADY))

        self.reaper.stop()
        self.assertIn(self.reaper.state.get(),
                      (State.SHUTTINGDOWN, State.SHUTDOWN))

        self.logger.debug("running second stop")
        self.reaper.stop()
        self.assertIn(self.reaper.state.get(),
                      (State.SHUTTINGDOWN, State.SHUTDOWN))

    def test_state_can_be_active(self):
        self.reaper.start()
        time.sleep(0.5)  # make sure reaper has a chance to get the lock
        self.assertEqual(self.reaper.state.get(), State.ACTIVE)

    @pytest.mark.timeout(20)
    def test_state_sets_timer_and_uses_it(self):
        self.reaper.MIN_SECONDS_BETWEEN_RUNS = 1
        self.reaper.seconds_between_runs = 1
        self.reaper.start(delay=2)
        self.assertEqual(self.reaper.seconds_between_runs, 1)
        self.reaper.state.wait_while(State.IDLE)  # Make sure the reaper started
        self.assertEqual(self.reaper.state.get(), State.ACTIVE)
        self.reaper.state.wait_while(State.ACTIVE)  # let the reaper finish its scan
        self.reaper.state.wait_while(State.IDLE)  # Make sure the reaper started a second time
        self.reaper.state.wait_while(State.ACTIVE)  # let the reaper finish its scan

    def test_start_delay(self):
        self.reaper.start(delay=90)
        self.assertEqual(self.reaper.state.get(), State.IDLE)

    @pytest.mark.timeout(20)
    def test_loop_of_start_stop_in_clumps(self):
        for _ in range(3):
            self.logger.debug(f"run {_} of rapid start/stop")
            self.reaper.start()
            self.assertIn(self.reaper.state.get(),
                          (State.IDLE, State.ACTIVE, State.STEADY))
            self.reaper.stop()
            self.assertIn(self.reaper.state.get(),
                          (State.SHUTTINGDOWN, State.SHUTDOWN))

    def test_fail_small_retain(self):
        with self.assertRaises(ValueError):
            self.reaper.retention_interval = 1

    def test_fail_small_run_interval(self):
        with self.assertRaises(ValueError):
            self.reaper.seconds_between_runs = 1

    def test_fail_start_two_reapers(self):
        self.reaper.start()
        self.assertIn(self.reaper.state.get(),
                      (State.IDLE, State.ACTIVE, State.STEADY))
        with self.assertRaises(RuntimeError):
            self.logger.debug("running second start")
            self.reaper.start()

    def test_fail_missing_config(self):
        with self.assertRaises(dataspace.DataSpaceConfigurationError):
            with mock.patch.object(dataspace.DataSourceLoader,
                                   "create_datasource") as source:
                source.return_value = MockSource()
                test_config = GLOBAL_CONFIG.copy()
                del test_config["dataspace"]
                Reaper(test_config)

    def test_fail_bad_config(self):
        with self.assertRaises(dataspace.DataSpaceConfigurationError):
            with mock.patch.object(dataspace.DataSourceLoader,
                                   "create_datasource") as source:
                source.return_value = MockSource()
                test_config = GLOBAL_CONFIG.copy()
                test_config["dataspace"] = 'somestring'
                Reaper(test_config)

    def test_fail_missing_config_key(self):
        with self.assertRaises(dataspace.DataSpaceConfigurationError):
            with mock.patch.object(dataspace.DataSourceLoader,
                                   "create_datasource") as source:
                source.return_value = MockSource()
                test_config = GLOBAL_CONFIG.copy()
                del test_config["dataspace"]["retention_interval_in_days"]
                Reaper(test_config)

    def test_fail_wrong_config_key(self):
        with self.assertRaises(ValueError):
            with mock.patch.object(dataspace.DataSourceLoader,
                                   "create_datasource") as source:
                source.return_value = MockSource()
                test_config = GLOBAL_CONFIG.copy()
                test_config["dataspace"]["retention_interval_in_days"] = "abc"
                Reaper(test_config)

    @pytest.mark.timeout(20)
    def test_source_fail_can_be_fixed(self):
        with mock.patch.object(MockSource,
                               "delete_data_older_than") as function:
            function.side_effect = KeyError
            self.reaper.start()
            time.sleep(1)  # make sure stack trace bubbles up before checking state
            self.assertEqual(self.reaper.state.get(), State.ERROR)

            self.reaper.stop()
            self.assertEqual(self.reaper.state.get(), State.ERROR)

            function.side_effect = None
            self.reaper.start(delay=30)
            self.assertEqual(self.reaper.state.get(), State.IDLE)

            self.reaper.stop()
            self.assertEqual(self.reaper.state.get(), State.SHUTDOWN)
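TestReaper patches DataSourceLoader.create_datasource to return a MockSource, whose definition is not included here. A minimal sketch, assuming only the single method the tests exercise (delete_data_older_than, the one patched in test_source_fail_can_be_fixed):

class MockSource:
    # Hypothetical stand-in for a dataspace data source; the Reaper tests above
    # only touch delete_data_older_than, so everything else is omitted.
    def delete_data_older_than(self, days):
        pass  # no-op: nothing to reap in unit tests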
Example #12
def test_fail_wrong_config_key(reaper, config):
    config["dataspace"]["retention_interval_in_days"] = "abc"
    with pytest.raises(ValueError):
        Reaper(config)
Example #13
def test_fail_missing_config_key(reaper, config):
    del config["dataspace"]["retention_interval_in_days"]
    with pytest.raises(dataspace.DataSpaceConfigurationError):
        Reaper(config)
Example #14
def test_fail_bad_config(reaper, config):
    config["dataspace"] = "somestring"
    with pytest.raises(dataspace.DataSpaceConfigurationError):
        Reaper(config)
Example #15
def test_fail_missing_config(reaper, config):
    del config["dataspace"]
    with pytest.raises(dataspace.DataSpaceConfigurationError):
        Reaper(config)
Example #16
def setUp(self):
    with mock.patch.object(dataspace.DataSourceLoader,
                           "create_datasource") as source:
        source.return_value = MockSource()
        GLOBAL_CONFIG["dataspace"]["retention_interval_in_days"] = 365
        self.reaper = Reaper(GLOBAL_CONFIG)
Example #17
class DecisionEngine(socketserver.ThreadingMixIn,
                     xmlrpc.server.SimpleXMLRPCServer):
    def __init__(self, global_config, channel_config_loader, server_address):
        xmlrpc.server.SimpleXMLRPCServer.__init__(
            self,
            server_address,
            logRequests=False,
            requestHandler=RequestHandler)
        signal.signal(signal.SIGHUP, self.handle_sighup)
        self.channel_config_loader = channel_config_loader
        self.global_config = global_config
        self.dataspace = dataspace.DataSpace(self.global_config)
        self.reaper = Reaper(self.global_config)
        self.startup_complete = Event()
        self.logger = structlog.getLogger(LOGGERNAME)
        self.logger = self.logger.bind(module=__name__.split(".")[-1],
                                       channel=DELOGGER_CHANNEL_NAME)
        self.logger.debug(f"DecisionEngine starting on {server_address}")

        exchange_name = self.global_config.get("exchange_name",
                                               "hepcloud_topic_exchange")
        self.logger.debug(f"Creating topic exchange {exchange_name}")
        self.exchange = Exchange(exchange_name, "topic")
        self.broker_url = self.global_config.get("broker_url",
                                                 "redis://localhost:6379/0")
        _verify_redis_server(self.broker_url)

        self.source_workers = SourceWorkers(self.exchange, self.broker_url,
                                            self.logger)
        self.channel_workers = ChannelWorkers()

        self.register_function(self.rpc_metrics, name="metrics")

        self.logger.info(
            f"DecisionEngine __init__ complete {server_address} with {self.broker_url}"
        )

    def get_logger(self):
        return self.logger

    def _dispatch(self, method, params):
        try:
            # methods allowed to be executed by rpc have 'rpc_' pre-pended
            func = getattr(self, "rpc_" + method)
        except AttributeError:
            raise Exception(f'method "{method}" is not supported')
        return func(*params)

    def block_while(self, state, timeout=None):
        self.logger.debug(
            f"Waiting for {state} or timeout={timeout} on channel_workers.")
        with self.channel_workers.unguarded_access() as workers:
            if not workers:
                self.logger.info("No active channels to wait on.")
                return "No active channels."
            for tm in workers.values():
                if tm.is_alive():
                    self.logger.debug(
                        f"Waiting for {tm.task_manager.name} to exit {state} state."
                    )
                    tm.wait_while(state, timeout)
        return f"No channels in {state} state."

    def _dataframe_to_table(self, df):
        return f"{tabulate.tabulate(df, headers='keys', tablefmt='psql')}\n"

    def _dataframe_to_vertical_tables(self, df):
        txt = ""
        for i in range(len(df)):
            txt += f"Row {i}\n"
            txt += f"{tabulate.tabulate(df.T.iloc[:, [i]], tablefmt='psql')}\n"
        return txt

    def _dataframe_to_column_names(self, df):
        columns = df.columns.values.reshape([len(df.columns), 1])
        return f"{tabulate.tabulate(columns, headers=['columns'], tablefmt='psql')}\n"

    def _dataframe_to_json(self, df):
        return f"{json.dumps(json.loads(df.to_json()), indent=4)}\n"

    def _dataframe_to_csv(self, df):
        return f"{df.to_csv()}\n"

    def rpc_ping(self):
        return "pong"

    def rpc_block_while(self, state_str, timeout=None):
        allowed_state = None
        try:
            allowed_state = ProcessingState.State[state_str]
        except Exception:
            return f"{state_str} is not a valid channel state."
        return self.block_while(allowed_state, timeout)

    def rpc_show_config(self, channel):
        """
        Show the configuration for a channel.

        :type channel: string
        """
        txt = ""
        channels = self.channel_config_loader.get_channels()
        if channel == "all":
            for ch in channels:
                txt += _channel_preamble(ch)
                txt += self.channel_config_loader.print_channel_config(ch)
            return txt

        if channel not in channels:
            return f"There is no active channel named {channel}."

        txt += _channel_preamble(channel)
        txt += self.channel_config_loader.print_channel_config(channel)
        return txt

    def rpc_show_de_config(self):
        return self.global_config.dump()

    @PRINT_PRODUCT_HISTOGRAM.time()
    def rpc_print_product(self,
                          product,
                          columns=None,
                          query=None,
                          types=False,
                          format=None):
        if not isinstance(product, str):
            raise ValueError(
                f"Requested product should be a string not {type(product)}")

        found = False
        txt = f"Product {product}: "
        with self.channel_workers.access() as workers:
            for ch, worker in workers.items():
                if not worker.is_alive():
                    txt += f"Channel {ch} is in not active\n"
                    self.logger.debug(
                        f"Channel:{ch} is in not active when running rpc_print_product"
                    )
                    continue

                produces = worker.get_produces()
                r = [x for x in list(produces.items()) if product in x[1]]
                if not r:
                    continue
                found = True
                txt += f" Found in channel {ch}\n"
                self.logger.debug(
                    f"Found channel:{ch} active when running rpc_print_product"
                )
                tm = self.dataspace.get_taskmanager(ch)
                self.logger.debug(
                    f"rpc_print_product - channel:{ch} taskmanager:{tm}")
                try:
                    data_block = datablock.DataBlock(
                        self.dataspace,
                        ch,
                        taskmanager_id=tm["taskmanager_id"],
                        sequence_id=tm["sequence_id"])
                    data_block.generation_id -= 1
                    df = data_block[product]
                    dfj = df.to_json()
                    self.logger.debug(
                        f"rpc_print_product - channel:{ch} task manager:{tm} datablock:{dfj}"
                    )
                    df = pd.read_json(dfj)
                    dataframe_formatter = self._dataframe_to_table
                    if format == "vertical":
                        dataframe_formatter = self._dataframe_to_vertical_tables
                    if format == "column-names":
                        dataframe_formatter = self._dataframe_to_column_names
                    if format == "json":
                        dataframe_formatter = self._dataframe_to_json
                    if types:
                        for column in df.columns:
                            df.insert(
                                df.columns.get_loc(column) + 1,
                                f"{column}.type",
                                df[column].transform(
                                    lambda x: type(x).__name__),
                            )
                    column_names = []
                    if columns:
                        column_names = columns.split(",")
                    if query:
                        if column_names:
                            txt += dataframe_formatter(
                                df.loc[:, column_names].query(query))
                        else:
                            txt += dataframe_formatter(df.query(query))

                    else:
                        if column_names:
                            txt += dataframe_formatter(df.loc[:, column_names])
                        else:
                            txt += dataframe_formatter(df)
                except Exception as e:  # pragma: no cover
                    txt += f"\t\t{e}\n"
        if not found:
            txt += "Not produced by any module\n"
        return txt[:-1]

    def rpc_print_products(self):
        with self.channel_workers.access() as workers:
            channel_keys = workers.keys()
            if not channel_keys:
                return "No channels are currently active.\n"

            width = max(len(x) for x in channel_keys) + 1
            txt = ""
            for ch, worker in workers.items():
                if not worker.is_alive():
                    txt += f"Channel {ch} is in ERROR state\n"
                    continue

                txt += f"channel: {ch:<{width}}, id = {worker.task_manager.id:<{width}}, state = {worker.get_state_name():<10} \n"
                tm = self.dataspace.get_taskmanager(ch)
                data_block = datablock.DataBlock(
                    self.dataspace,
                    ch,
                    taskmanager_id=tm["taskmanager_id"],
                    sequence_id=tm["sequence_id"])
                data_block.generation_id -= 1
                channel_config = self.channel_config_loader.get_channels()[ch]
                produces = worker.get_produces()
                for i in ("sources", "transforms", "logicengines",
                          "publishers"):
                    txt += f"\t{i}:\n"
                    modules = channel_config.get(i, {})
                    for mod_name in modules.keys():
                        txt += f"\t\t{mod_name}\n"
                        products = produces.get(mod_name, [])
                        for product in products:
                            try:
                                df = data_block[product]
                                df = pd.read_json(df.to_json())
                                txt += f"{tabulate.tabulate(df, headers='keys', tablefmt='psql')}\n"
                            except Exception as e:  # pragma: no cover
                                txt += f"\t\t\t{e}\n"
        return txt[:-1]

    @STATUS_HISTOGRAM.time()
    def rpc_status(self):
        with self.channel_workers.access() as workers:
            channel_keys = workers.keys()
            if not channel_keys:
                return "No channels are currently active.\n" + self.reaper_status(
                )

            txt = ""
            width = max(len(x) for x in channel_keys) + 1
            for ch, worker in workers.items():
                txt += f"channel: {ch:<{width}}, id = {worker.task_manager.id:<{width}}, state = {worker.get_state_name():<10} \n"
                produces = worker.get_produces()
                consumes = worker.get_consumes()
                channel_config = self.channel_config_loader.get_channels()[ch]
                for i in ("sources", "transforms", "logicengines",
                          "publishers"):
                    txt += f"\t{i}:\n"
                    modules = channel_config.get(i, {})
                    for mod_name in modules.keys():
                        txt += f"\t\t{mod_name}\n"
                        txt += f"\t\t\tconsumes : {consumes.get(mod_name, [])}\n"
                        txt += f"\t\t\tproduces : {produces.get(mod_name, [])}\n"
        return txt + self.reaper_status()

    def rpc_queue_status(self):
        status = redis_stats(self.broker_url, self.exchange.name)
        return f"\n{tabulate.tabulate(status, headers=['Source name', 'Queue name', 'Unconsumed messages'])}"

    def rpc_stop(self):
        self.shutdown()
        self.stop_channels()
        self.reaper_stop()
        self.dataspace.close()

        if not self.global_config.get("no_webserver"):
            cherrypy.engine.exit()

        de_logger.stop_queue_logger()
        return "OK"

    def start_channel(self, channel_name, channel_config):
        channel_config = copy.deepcopy(channel_config)
        with START_CHANNEL_HISTOGRAM.labels(channel_name).time():
            # NB: Possibly override channel name
            channel_name = channel_config.get("channel_name", channel_name)
            source_configs = channel_config.pop("sources")
            src_workers = self.source_workers.update(channel_name,
                                                     source_configs)
            module_workers = validated_workflow(channel_name, src_workers,
                                                channel_config, self.logger)

            queue_info = [(worker.queue.name, worker.key)
                          for worker in src_workers.values()]
            self.logger.debug(f"Building TaskManger for {channel_name}")
            task_manager = TaskManager.TaskManager(
                channel_name,
                module_workers,
                dataspace.DataSpace(self.global_config),
                source_products(src_workers),
                self.exchange,
                self.broker_url,
                queue_info,
            )
            self.logger.debug(f"Building Worker for {channel_name}")
            worker = ChannelWorker(task_manager, self.global_config["logger"])
            WORKERS_COUNT.inc()
            with self.channel_workers.access() as workers:
                workers[channel_name] = worker

            # The channel must be started first so it can listen for the messages from the sources.
            self.logger.debug(f"Trying to start {channel_name}")
            worker.start()
            self.logger.info(f"Channel {channel_name} started")

            worker.wait_while(ProcessingState.State.BOOT)

            # Start any sources that are not yet alive.
            for key, src_worker in src_workers.items():
                if src_worker.is_alive():
                    continue
                if src_worker.exitcode == 0:  # pragma: no cover
                    # This can happen if the source's acquire method runs only once (e.g. when testing)
                    # and the first process completes before the next channel can use it.
                    raise RuntimeError(
                        f"The {key} source has already completed and cannot be used by channel {channel_name}."
                    )

                src_worker.start()
                self.logger.debug(
                    f"Started process {src_worker.pid} for source {key}")

            worker.wait_while(ProcessingState.State.ACTIVE)

    def start_channels(self):
        self.channel_config_loader.load_all_channels()

        if not self.channel_config_loader.get_channels():
            self.logger.info(
                "No channel configurations available in " +
                f"{self.channel_config_loader.channel_config_dir}")
        else:
            self.logger.debug(
                f"Found channels: {self.channel_config_loader.get_channels().items()}"
            )

        # FIXME: Should figure out a way to load the channels in parallel.  Unfortunately, there are data races that
        #        occur when doing that (observed with Python 3.10).
        for name, config in self.channel_config_loader.get_channels().items():
            try:
                self.start_channel(name, config)
            except Exception as e:
                self.logger.exception(f"Channel {name} failed to start: {e}")

    def rpc_start_channel(self, channel_name):
        with self.channel_workers.access() as workers:
            if channel_name in workers:
                return f"ERROR, channel {channel_name} is running"

        success, result = self.channel_config_loader.load_channel(channel_name)
        if not success:
            return result
        self.start_channel(channel_name, result)
        return "OK"

    def rpc_start_channels(self):
        self.start_channels()
        return "OK"

    def rpc_stop_channel(self, channel):
        return self.rpc_rm_channel(channel, None)

    def rpc_kill_channel(self, channel, timeout=None):
        if timeout is None:
            timeout = self.global_config.get("shutdown_timeout", 10)
        return self.rpc_rm_channel(channel, timeout)

    def rpc_rm_channel(self, channel, maybe_timeout):
        rc = self.rm_channel(channel, maybe_timeout)
        if rc == StopState.NotFound:
            return f"No channel found with the name {channel}."
        elif rc == StopState.Terminated:
            if maybe_timeout == 0:
                return f"Channel {channel} has been killed."
            # Would be better to use something like the inflect
            # module, but that introduces another dependency.
            suffix = "s" if maybe_timeout > 1 else ""
            return f"Channel {channel} has been killed due to shutdown timeout ({maybe_timeout} second{suffix})."
        assert rc == StopState.Clean
        WORKERS_COUNT.dec()
        return f"Channel {channel} stopped cleanly."

    def rm_channel(self, channel, maybe_timeout):
        with RM_CHANNEL_HISTOGRAM.labels(channel).time():
            rc = None
            with self.channel_workers.access() as workers:
                worker = workers.get(channel)
                if worker is None:
                    return StopState.NotFound
                sources_to_prune = worker.task_manager.routing_keys
                self.logger.debug(f"Trying to stop {channel}")
                rc = self.stop_worker(worker, maybe_timeout)
                del workers[channel]
                self.logger.debug(f"Channel {channel} removed ({rc})")
                self.source_workers.prune(sources_to_prune)
            return rc

    def stop_worker(self, worker, timeout):
        if worker.is_alive():
            self.logger.debug("Trying to take worker offline")
            worker.task_manager.take_offline()
            worker.join(timeout)
        if worker.exitcode is None:
            worker.terminate()
            return StopState.Terminated
        else:
            return StopState.Clean

    def stop_channels(self):
        timeout = self.global_config.get("shutdown_timeout", 10)
        with self.channel_workers.access() as workers:
            for worker in workers.values():
                self.stop_worker(worker, timeout)
            workers.clear()
        self.source_workers.remove_all(timeout)

    def rpc_stop_channels(self):
        self.stop_channels()
        return "All channels stopped."

    def handle_sighup(self, signum, frame):
        self.reaper_stop()
        self.stop_channels()
        self.start_channels()
        self.reaper_start(delay=self.global_config["dataspace"].get(
            "reaper_start_delay_seconds", 1818))

    def rpc_get_log_level(self):
        engineloglevel = self.get_logger().getEffectiveLevel()
        return logging.getLevelName(engineloglevel)

    def rpc_get_channel_log_level(self, channel):
        with self.channel_workers.access() as workers:
            worker = workers.get(channel)
            if worker is None:
                return f"No channel found with the name {channel}."

            if not worker.is_alive():
                return f"Channel {channel} is in ERROR state."
            return logging.getLevelName(worker.task_manager.get_loglevel())

    def rpc_set_channel_log_level(self, channel, log_level):
        """Assumes log_level is a string corresponding to the supported logging-module levels."""
        with self.channel_workers.access() as workers:
            worker = workers.get(channel)
            if worker is None:
                return f"No channel found with the name {channel}."

            if not worker.is_alive():
                return f"Channel {channel} is in ERROR state."

            log_level_code = getattr(logging, log_level)
            if worker.task_manager.get_loglevel() == log_level_code:
                return f"Nothing to do. Current log level is : {log_level}"
            worker.task_manager.set_loglevel_value(log_level)
        return f"Log level changed to : {log_level}"

    def rpc_reaper_start(self, delay=0):
        """
        Start the reaper process after 'delay' seconds.
        Default 0 seconds delay.
        :type delay: int
        """
        self.reaper_start(delay)
        return "OK"

    def reaper_start(self, delay):
        self.reaper.start(delay)

    def rpc_reaper_stop(self):
        self.reaper_stop()
        return "OK"

    def reaper_stop(self):
        self.reaper.stop()

    def rpc_reaper_status(self):
        interval = self.reaper.retention_interval
        state = self.reaper.state.get()
        return f"reaper:\n\tstate: {state}\n\tretention_interval: {interval}"

    def reaper_status(self):
        interval = self.reaper.retention_interval
        state = self.reaper.state.get()
        return f"\nreaper:\n\tstate: {state}\n\tretention_interval: {interval}\n"

    def rpc_query_tool(self, product, format=None, start_time=None):
        with QUERY_TOOL_HISTOGRAM.labels(product).time():
            found = False
            result = pd.DataFrame()
            txt = f"Product {product}: "

            with self.channel_workers.access() as workers:
                for ch, worker in workers.items():
                    if not worker.is_alive():
                        txt += f"Channel {ch} is in not active\n"
                        continue

                    produces = worker.get_produces()
                    r = [x for x in list(produces.items()) if product in x[1]]
                    if not r:
                        continue
                    found = True
                    txt += f" Found in channel {ch}\n"

                    if start_time:
                        tms = self.dataspace.get_taskmanagers(
                            ch, start_time=start_time)
                    else:
                        tms = [self.dataspace.get_taskmanager(ch)]
                    for tm in tms:
                        try:
                            data_block = datablock.DataBlock(
                                self.dataspace,
                                ch,
                                taskmanager_id=tm["taskmanager_id"],
                                sequence_id=tm["sequence_id"])
                            products = data_block.get_dataproducts(product)
                            for p in products:
                                df = p["value"]
                                if df.shape[0] > 0:
                                    df["channel"] = [tm["name"]] * df.shape[0]
                                    df["taskmanager_id"] = [
                                        p["taskmanager_id"]
                                    ] * df.shape[0]
                                    df["generation_id"] = [p["generation_id"]
                                                           ] * df.shape[0]
                                    result = result.append(df)
                        except Exception as e:  # pragma: no cover
                            txt += f"\t\t{e}\n"

            if found:
                dataframe_formatter = self._dataframe_to_table
                if format == "csv":
                    dataframe_formatter = self._dataframe_to_csv
                if format == "json":
                    dataframe_formatter = self._dataframe_to_json
                result = result.reset_index(drop=True)
                txt += dataframe_formatter(result)
            else:
                txt += "Not produced by any module\n"
            return txt

    def start_webserver(self):
        """
        Start CherryPy webserver using configured port.  If port is not configured
        use default webserver port.
        """
        _socket_host = "0.0.0.0"
        if self.global_config.get("webserver") and isinstance(
                self.global_config.get("webserver"), dict):
            _port = self.global_config["webserver"].get(
                "port", DEFAULT_WEBSERVER_PORT)
        else:  # pragma: no cover
            # unit tests use a random port
            _port = DEFAULT_WEBSERVER_PORT

        with contextlib.suppress(Exception):
            self.logger.debug(
                f"Trying to start metrics server on {_socket_host}:{_port}")

        cherrypy.config.update({
            "server.socket_port": _port,
            "server.socket_host": _socket_host,
            "server.shutdown_timeout": 1
        })
        cherrypy.engine.signals.subscribe()
        cherrypy.tree.mount(self)
        # we know for sure the cherrypy logger is working, so use that too
        cherrypy.log(
            f"Trying to start metrics server on {_socket_host}:{_port}")
        cherrypy.engine.start()
        with contextlib.suppress(Exception):
            self.logger.debug("Started CherryPy server")

    @cherrypy.expose
    def metrics(self):
        return self.rpc_metrics()

    @METRICS_HISTOGRAM.time()
    def rpc_metrics(self):
        """
        Display collected metrics
        """
        try:
            return display_metrics()
        except Exception as e:  # pragma: no cover
            self.logger.error(e)
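Because _dispatch strips the rpc_ prefix, clients invoke these methods over XML-RPC by their bare names. A sketch using only the standard library; the host and port are assumptions, not taken from the examples:

import xmlrpc.client

de = xmlrpc.client.ServerProxy("http://localhost:8888")  # hypothetical address
print(de.ping())           # dispatched to rpc_ping -> "pong"
print(de.reaper_status())  # dispatched to rpc_reaper_status
print(de.status())         # dispatched to rpc_status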