def test_jitter(self):
        """With jitter on, later durations are expected to fall between the
        minimum and the un-jittered doubling value; the first call and the
        first call after reset() are expected to equal the minimum exactly."""
        backoff = Backoff(min_ms=100.0, max_ms=10000.0, factor=2.0, jitter=True)

        self.assertEqual(backoff.duration(), to_seconds(100.0))
        for upper_ms in (200.0, 400.0):
            self.assert_between(backoff.duration(), to_seconds(100.0), to_seconds(upper_ms))

        backoff.reset()
        self.assertEqual(backoff.duration(), to_seconds(100.0))
    def test_integers(self):
        """Integer min/max/factor arguments give 100 ms, 200 ms, 400 ms, and
        100 ms again after reset()."""
        backoff = Backoff(min_ms=100, max_ms=10000, factor=2)

        for expected_ms in (100.0, 200.0, 400.0):
            self.assertEqual(backoff.duration(), to_seconds(expected_ms))

        backoff.reset()
        self.assertEqual(backoff.duration(), to_seconds(100.0))
    def test_factor(self):
        """A non-integer factor (1.5) grows 100 ms -> 150 ms -> 225 ms and
        restarts at 100 ms after reset()."""
        backoff = Backoff(min_ms=100, max_ms=10000, factor=1.5)

        for expected_ms in (100.0, 150.0, 225.0):
            self.assertEqual(backoff.duration(), to_seconds(expected_ms))

        backoff.reset()
        self.assertEqual(backoff.duration(), to_seconds(100.0))
Beispiel #4
0
    def test_min_bigger_than_max(self):
        """When min_ms exceeds max_ms, every duration is expected to be 1.0
        (presumably max_ms=1000.0 expressed in seconds), before and after reset()."""
        backoff = Backoff(min_ms=10000.0, max_ms=1000.0, factor=2)

        for _ in range(3):
            self.assertEqual(backoff.duration(), 1.0)

        backoff.reset()
        self.assertEqual(backoff.duration(), 1.0)
class TestBackoff(unittest.TestCase, CustomAssertions):
    """Checks growth, reset, attempt counting and jitter of ``Backoff``."""

    def setUp(self):
        # Shared doubling schedule: 100 ms minimum, 10000 ms maximum.
        self.b = Backoff(min_ms=100.0, max_ms=10000.0, factor=2.0)

    def _check_sequence(self, backoff, millis):
        """Assert that successive ``duration()`` calls equal ``millis`` (in order)."""
        for ms in millis:
            self.assertEqual(backoff.duration(), to_seconds(ms))

    def test_defaults(self):
        self._check_sequence(self.b, (100.0, 200.0, 400.0))
        self.b.reset()
        self._check_sequence(self.b, (100.0,))

    def test_factor(self):
        backoff = Backoff(min_ms=100, max_ms=10000, factor=1.5)
        self._check_sequence(backoff, (100.0, 150.0, 225.0))
        backoff.reset()
        self._check_sequence(backoff, (100.0,))

    def test_for_attempt(self):
        for attempt, ms in enumerate((100.0, 200.0, 400.0)):
            self.assertEqual(self.b.for_attempt(attempt), to_seconds(ms))
        self.b.reset()
        self.assertEqual(self.b.for_attempt(0), to_seconds(100.0))

    def test_get_attempt(self):
        self.assertEqual(self.b.attempt(), 0)
        # Each duration() call is expected to advance the attempt counter by one.
        for attempts_done, ms in enumerate((100.0, 200.0, 400.0), start=1):
            self.assertEqual(self.b.duration(), to_seconds(ms))
            self.assertEqual(self.b.attempt(), attempts_done)
        self.b.reset()
        self.assertEqual(self.b.attempt(), 0)
        self.assertEqual(self.b.duration(), to_seconds(100.0))
        self.assertEqual(self.b.attempt(), 1)

    def test_jitter(self):
        backoff = Backoff(min_ms=100.0, max_ms=10000.0, factor=2.0, jitter=True)

        self.assertEqual(backoff.duration(), to_seconds(100.0))
        for upper_ms in (200.0, 400.0):
            self.assert_between(backoff.duration(), to_seconds(100.0), to_seconds(upper_ms))
        backoff.reset()
        self.assertEqual(backoff.duration(), to_seconds(100.0))

    def test_integers(self):
        backoff = Backoff(min_ms=100, max_ms=10000, factor=2)
        self._check_sequence(backoff, (100.0, 200.0, 400.0))
        backoff.reset()
        self._check_sequence(backoff, (100.0,))
Beispiel #6
0
class Worker(threading.Thread):
    """Thread that repeatedly calls ``work()``, waiting a backoff period between calls.

    Subclasses override ``setup``, ``work`` and ``teardown``. When ``work``
    returns a truthy value the backoff is reset, so the next call comes after
    the minimum wait; otherwise each wait grows according to
    ``Backoff(min_ms=100, max_ms=30000, factor=2)``. Exceptions raised by
    ``work`` are logged and treated as "nothing done"; they do not stop the thread.

    Args:
        name: Thread name; also bound into the logger as ``worker``.
        logger: Logger providing ``bind()`` (presumably structlog — confirm).
    """

    def __init__(self, name, logger):
        super().__init__()
        self.name = name
        self.logger = logger.bind(worker=name)
        # Set by stop(); checked once per loop iteration in run().
        self.exit = False
        # Lets stop() cut the current backoff wait short.
        self.event = threading.Event()
        self.backoff = Backoff(min_ms=100,
                               max_ms=30000,
                               factor=2,
                               jitter=False)
        # time.time() of the last successful work() call (construction time initially).
        self.lastWorked = time.time()

    def run(self):
        """Thread body: setup(), the work/wait loop until stop(), then teardown()."""
        self.logger.info('Worker started')
        self.setup()
        while not self.exit:
            try:
                worked = self.work()
            except Exception as e:  # keep the worker alive when a single job fails
                self.logger.warning('Worker job failed', exc_info=e)
                worked = False

            if worked:
                self.backoff.reset()
                self.lastWorked = time.time()

            self.event.wait(self.backoff.duration())
            self.event.clear()

        self.teardown()
        self.logger.info('Worker stopped')

    def stop(self):
        """Ask the loop to finish and wake it from its wait; does not join the thread."""
        self.exit = True
        self.event.set()

    def setup(self):
        """Hook run once on the worker thread before the loop starts."""
        pass

    def teardown(self):
        """Hook run once on the worker thread after the loop ends."""
        pass

    def work(self):
        """Do one unit of work; return True if anything was done. Override in subclasses."""
        return False

    def idle(self):
        """Return the seconds elapsed since the last successful ``work()`` call."""
        return time.time() - self.lastWorked
class AsyncParticipant(threading.Thread):
    """Daemon thread that drives a xaynet participant in the background.

    The thread repeatedly calls ``tick()`` on the wrapped
    ``xaynet_sdk.Participant`` and sets ``notifier`` when a new global model
    is available. Between ticks it waits for a ``Backoff`` period; the backoff
    is reset whenever the participant reports progress.

    Args:
        coordinator_url: URL of the coordinator, forwarded to ``xaynet_sdk.Participant``.
        notifier: Event-like object (``is_set``/``set``/``clear``) signalled
            when a new global model is available.
        state: Forwarded to ``xaynet_sdk.Participant``; presumably a state
            previously returned by ``stop`` or ``None`` — confirm against the SDK.
        scalar: Forwarded to ``xaynet_sdk.Participant``.
    """

    def __init__(
        self,
        coordinator_url: str,
        notifier,
        state,
        scalar,
    ):
        # xaynet rust participant
        self._xaynet_participant = xaynet_sdk.Participant(
            coordinator_url, scalar, state)

        self._exit_event = threading.Event()
        self._poll_period = Backoff(min_ms=100,
                                    max_ms=10000,
                                    factor=1.2,
                                    jitter=False)

        # new global model notifier
        self._notifier = notifier

        # Calls into the external lib are thread-safe (https://stackoverflow.com/a/42023362).
        # However, if a user calls `stop` in the middle of a `_tick` call, `save` is
        # executed (which consumes the participant) and every following call fails on
        # an uninitialized participant. Therefore `tick` and `save` are serialized by this lock.
        self._tick_lock = threading.Lock()

        super().__init__(daemon=True)

    def run(self):
        """Thread body: tick until stopped; any exception ends the loop and is logged."""
        try:
            self._run()
        except Exception as err:  # pylint: disable=broad-except
            LOG.error("unrecoverable error: %s shut down participant", err)
            self._exit_event.set()

    def _notify(self):
        """Set the notifier unless it is already set."""
        if not self._notifier.is_set():
            LOG.debug("notify that a new global model is available")
            self._notifier.set()

    def _run(self):
        while not self._exit_event.is_set():
            self._tick()

    def _tick(self):
        """Advance the participant once, notify about new models, then wait."""
        with self._tick_lock:
            # `stop()` sets the exit event before taking the lock, so it may have
            # consumed the participant between the loop check in `_run` and here.
            # Ticking now would fail on an uninitialized participant.
            if self._exit_event.is_set():
                return
            self._xaynet_participant.tick()
            new_global_model = self._xaynet_participant.new_global_model()
            made_progress = self._xaynet_participant.made_progress()

        if new_global_model:
            self._notify()

        # Progress restarts polling at the minimum period; otherwise back off further.
        if made_progress:
            self._poll_period.reset()
        self._exit_event.wait(timeout=self._poll_period.duration())

    def get_global_model(self) -> Optional[list]:
        """
        Fetches the current global model. This method can be called at any time. If no global
        model exists (usually in the first round), the method returns `None`.

        Returns:
            The current global model in the form of a list or `None`. The data type of the
            elements match the data type defined in the coordinator configuration.

        Raises:
            GlobalModelUnavailable: If the participant cannot connect to the coordinator to get
                the global model.
            GlobalModelDataTypeMisMatch: If the data type of the global model does not match
                the data type defined in the coordinator configuration.
        """
        LOG.debug("get global model")
        self._notifier.clear()
        with self._tick_lock:
            return self._xaynet_participant.global_model()

    def set_local_model(self, local_model: list):
        """
        Sets a local model. This method can be called at any time. Internally the
        participant first caches the local model. As soon as the participant is selected as the
        update participant, the currently cached local model is used. This means that the cache
        is empty after this operation.

        If a local model is already in the cache and `set_local_model` is called with a new local
        model, the current cached local model will be replaced by the new one.
        If the participant is an update participant and there is no local model in the cache,
        the participant waits until a local model is set or until a new round has been started.

        Args:
            local_model: The local model in the form of a list. The data type of the
                elements must match the data type defined in the coordinator configuration.

        Raises:
            LocalModelLengthMisMatch: If the length of the local model does not match the
                length defined in the coordinator configuration.
            LocalModelDataTypeMisMatch: If the data type of the local model does not match
                the data type defined in the coordinator configuration.
        """
        LOG.debug("set local model in model store")
        with self._tick_lock:
            self._xaynet_participant.set_model(local_model)

    def stop(self) -> List[int]:
        """
        Stops the execution of the participant and returns its serialized state.
        The serialized state can be passed to the `spawn_async_participant` function
        to restore a participant.

        After calling `stop`, the participant is consumed. Every further method
        call on the handle of `AsyncParticipant` leads to an `UninitializedParticipant`
        exception.

        Note:
            The serialized state contains unencrypted **private key(s)**. If used
            in production, it is important that the serialized state is securely saved.

        Returns:
            The serialized state of the participant.
        """
        LOG.debug("stop participant")
        # Set before taking the lock so a pending `_tick` sees it and skips the tick.
        self._exit_event.set()
        self._notifier.clear()
        with self._tick_lock:
            return self._xaynet_participant.save()
Beispiel #8
0
class InternalParticipant(threading.Thread):
    """Daemon thread that runs a user-defined participant against a xaynet coordinator.

    On every tick the wrapped ``xaynet_sdk.Participant`` is advanced. When a
    new global model is available (or the previous fetch failed), it is fetched,
    deserialized by the user participant and passed to ``on_new_global_model``.
    When xaynet asks for a local model and the user participant agrees to take
    part, a training round runs and its serialized result is handed to xaynet.

    Args:
        coordinator_url: URL of the coordinator, forwarded to ``xaynet_sdk.Participant``.
        participant: Callable (typically a class) instantiated on this thread in
            ``run`` as ``participant(*p_args, **p_kwargs)``.
        p_args: Positional arguments for ``participant``.
        p_kwargs: Keyword arguments for ``participant``.
        state: Forwarded to ``xaynet_sdk.Participant``; presumably a state
            previously returned by ``stop`` or ``None`` — confirm against the SDK.
        scalar: Forwarded to ``xaynet_sdk.Participant``.
    """

    def __init__(
        self,
        coordinator_url: str,
        participant,
        p_args,
        p_kwargs,
        state,
        scalar,
    ):
        # xaynet rust participant
        self._xaynet_participant = xaynet_sdk.Participant(
            coordinator_url, scalar, state)

        # https://github.com/python/cpython/blob/3.9/Lib/multiprocessing/process.py#L80
        # Stores the Participant class with its args and kwargs.
        # The participant is created in the `run` method so that the participant / ML
        # model is initialized on the participant thread. Otherwise the participant lives
        # on the main thread, which can create issues with some of the ML frameworks.
        self._participant = participant
        self._p_args = tuple(p_args)
        self._p_kwargs = dict(p_kwargs)

        self._exit_event = threading.Event()
        self._poll_period = Backoff(min_ms=100,
                                    max_ms=10000,
                                    factor=1.2,
                                    jitter=False)

        # global model cache
        self._global_model = None
        # When True, the next tick retries fetching the global model.
        self._error_on_fetch_global_model = False

        # Serializes `tick` (and training) with `save` in `stop`.
        self._tick_lock = threading.Lock()

        super().__init__(daemon=True)

    def run(self):
        """Thread body: build the user participant, then tick until stopped."""
        try:
            # Keyword arguments must be unpacked with `**`; `*` would pass the
            # dict's keys as extra positional arguments.
            self._participant = self._participant(*self._p_args, **self._p_kwargs)
            self._run()
        except Exception as err:  # pylint: disable=broad-except
            LOG.error("unrecoverable error: %s shut down participant", err)
            self._exit_event.set()

    def _fetch_global_model(self):
        """Fetch and deserialize the global model into the cache, flagging failures for retry."""
        LOG.debug("fetch global model")
        try:
            global_model = self._xaynet_participant.global_model()
        except (
                xaynet_sdk.GlobalModelUnavailable,
                xaynet_sdk.GlobalModelDataTypeMisMatch,
        ) as err:
            LOG.warning("failed to get global model: %s", err)
            self._error_on_fetch_global_model = True
        else:
            if global_model is not None:
                self._global_model = self._participant.deserialize_training_input(
                    global_model)
            else:
                self._global_model = None
            self._error_on_fetch_global_model = False

    def _train(self):
        """Run one training round on the cached global model and submit the result."""
        LOG.debug("train model")
        data = self._participant.train_round(self._global_model)
        local_model = self._participant.serialize_training_result(data)
        try:
            self._xaynet_participant.set_model(local_model)
        except (
                xaynet_sdk.LocalModelLengthMisMatch,
                xaynet_sdk.LocalModelDataTypeMisMatch,
        ) as err:
            LOG.warning("failed to set local model: %s", err)

    def _run(self):
        while not self._exit_event.is_set():
            self._tick()

    def _tick(self):
        """Advance the participant once, fetch/train as requested, then wait."""
        with self._tick_lock:
            # `stop()` sets the exit event before taking the lock, so it may have
            # consumed the participant between the loop check in `_run` and here.
            if self._exit_event.is_set():
                return
            self._xaynet_participant.tick()

            if (self._xaynet_participant.new_global_model()
                    or self._error_on_fetch_global_model):
                self._fetch_global_model()

                if not self._error_on_fetch_global_model:
                    self._participant.on_new_global_model(self._global_model)

            if (self._xaynet_participant.should_set_model()
                    and self._participant.participate_in_update_task()
                    and not self._error_on_fetch_global_model):
                self._train()

            made_progress = self._xaynet_participant.made_progress()

        # Progress restarts polling at the minimum period; otherwise back off further.
        if made_progress:
            self._poll_period.reset()
        self._exit_event.wait(timeout=self._poll_period.duration())

    def stop(self) -> List[int]:
        """
        Stops the execution of the participant and returns its serialized state.
        The serialized state can be passed to the `spawn_participant` function
        to restore a participant.

        After calling `stop`, the participant is consumed. Every further method
        call on the handle of `InternalParticipant` leads to an `UninitializedParticipant`
        exception.

        Note:
            The serialized state contains unencrypted **private key(s)**. If used
            in production, it is important that the serialized state is securely saved.

        Returns:
            The serialized state of the participant.
        """
        LOG.debug("stopping participant")
        # Set before taking the lock so a pending `_tick` sees it and skips the tick.
        self._exit_event.set()
        with self._tick_lock:
            state = self._xaynet_participant.save()
            LOG.debug("participant stopped")
        self._participant.on_stop()
        return state
def cli(ctx,
        workflow_id,
        invocation_id,
        exit_early=False,
        backoff_min=1,
        backoff_max=60):
    """Given a workflow and invocation id, wait until that invocation is
    complete (or one or more steps have errored)

    This will exit with the following error codes:

    - 0: done successfully
    - 1: running (if --exit_early)
    - 2: failure
    - 3: unknown
    """
    # backoff_min / backoff_max are given in seconds; Backoff takes milliseconds.
    backoff = Backoff(min_ms=backoff_min * 1000,
                      max_ms=backoff_max * 1000,
                      factor=2,
                      jitter=True)

    prev_state = None
    while True:
        # Fetch the current state
        latest_state = ctx.gi.workflows.show_invocation(
            workflow_id, invocation_id)
        # Step states, ignoring steps without a state and deleted steps
        states = [
            step['state'] for step in latest_state['steps']
            if step['state'] is not None and step['state'] != 'deleted'
        ]
        # Get a str based state representation
        state_rep = '|'.join(map(str, states))

        # Any change in step states restarts polling at the minimum interval.
        if state_rep != prev_state:
            backoff.reset()
        prev_state = state_rep

        # If it's scheduled, then let's look at steps. Otherwise steps probably don't exist yet.
        if latest_state['state'] == 'scheduled':
            ctx.vlog("Checking workflow %s states: %s", workflow_id, state_rep)

            # Terminal outcomes are checked before --exit_early so that a finished
            # or failed invocation is not misreported as "running" (exit code 1).
            # NOTE: all() of an empty list is True, so a scheduled invocation with
            # no stateful steps is reported as done.
            if all(state == 'ok' for state in states):
                print(json.dumps({'state': 'done', 'job_states': states}))
                ctx.exit(0)

            # Conditions on which to exit immediately (i.e. due to a failure)
            if any(state in ('error', 'paused') for state in states):
                print(json.dumps({'state': 'failure', 'job_states': states}))
                ctx.exit(2)

            if exit_early:
                print(json.dumps({'state': 'running', 'job_states': states}))
                ctx.exit(1)
        else:
            ctx.vlog("Waiting for invocation to be scheduled")

            if exit_early:
                print(json.dumps({'state': 'unscheduled'}))
                ctx.exit(0)

        time.sleep(backoff.duration())
    # Unreachable: the loop above only ends through ctx.exit().
    ctx.exit(3)
Beispiel #10
0
class Client:
    """Asyncio client for an alarm panel reached through a ``Connection``.

    Commands are wrapped in ``Packet`` objects and written as ASCII lines.
    Received lines are decoded into ``BaseEvent`` objects, passed to the
    optional ``on_event_received`` callback and then to ``alarm``.

    :param connection: Connection to use; when omitted an ``IP232Connection``
        is built from ``host``, ``port`` and ``loop`` (all three required).
    :param update_interval: Frequency (in seconds) to trigger a full state
        refresh
    :param infer_arming_state: Infer the `DISARMED` arming state only via
        system status events. This works around a bug with some panels
        (`<v5.8`) which emit `update.status = []` when they are armed.
    :param alarm: Alarm state holder; a new ``Alarm`` is created when omitted.
    """
    def __init__(self,
                 connection: Optional[Connection] = None,
                 host: Optional[str] = None,
                 port: Optional[int] = None,
                 loop: Optional[asyncio.AbstractEventLoop] = None,
                 update_interval: int = 60,
                 infer_arming_state: bool = False,
                 alarm: Optional[Alarm] = None):
        if connection is None:
            assert host is not None
            assert port is not None
            assert loop is not None
            connection = IP232Connection(host=host, port=port, loop=loop)
        self._connection = connection

        self.alarm = alarm if alarm is not None else Alarm(
            infer_arming_state=infer_arming_state)

        self._on_event_received: Optional[Callable[[BaseEvent], None]] = None
        self._closed = False
        self._update_interval = update_interval
        self._backoff = Backoff()
        self._connect_lock = asyncio.Lock()
        # Time of the last received data (or connection attempt); drives _should_reconnect().
        self._last_recv: Optional[datetime.datetime] = None

    async def arm_away(self, code: Optional[str] = None) -> None:
        return await self.send_command('A{}E'.format(code or ''))

    async def arm_home(self, code: Optional[str] = None) -> None:
        return await self.send_command('H{}E'.format(code or ''))

    async def disarm(self, code: str) -> None:
        return await self.send_command('{}E'.format(code))

    async def panic(self, code: str) -> None:
        return await self.send_command('*{}#'.format(code))

    async def aux(self, output_id: int, state: bool = True) -> None:
        # The output id is sent twice, followed by '*' (on) or '#' (off).
        suffix = '*' if state else '#'
        return await self.send_command('{}{}{}'.format(output_id, output_id, suffix))

    async def update(self) -> None:
        """Force update of alarm status and zones"""
        _LOGGER.debug("Requesting state update from server (S00, S14)")
        unsealed_zones = self.send_command('S00')
        arming_status = self.send_command('S14')
        await asyncio.gather(unsealed_zones, arming_status)

    async def _connect(self) -> None:
        """Ensure the connection is open, closing it first if it went stale.

        Failed attempts are retried after a backoff delay; the backoff is reset
        once connected. Serialized by ``_connect_lock``.
        """
        async with self._connect_lock:
            if self._should_reconnect():
                _LOGGER.debug('Closing stale connection and reconnecting')
                await self._connection.close()

            while not self._connection.connected:
                _LOGGER.debug('Attempting to connect')
                try:
                    await self._connection.connect()
                except OSError as err:  # ConnectionRefusedError is an OSError subclass
                    _LOGGER.warning('Failed to connect: %s', err)
                    await sleep(self._backoff.duration())
                self._last_recv = datetime.datetime.now()

            self._backoff.reset()

    async def send_command(self, command: str) -> None:
        """Wrap ``command`` in a user-interface packet and write it as one ASCII line."""
        packet = Packet(
            address=0x00,
            seq=0x00,
            command=CommandType.USER_INTERFACE,
            data=command,
            timestamp=None,
        )
        await self._connect()
        line = packet.encode() + '\r\n'
        _LOGGER.debug('Sending payload: %s', repr(line))
        return await self._connection.write(line.encode('ascii'))

    async def _recv_loop(self) -> None:
        """Read and dispatch data until closed, reconnecting whenever a read yields None."""
        while not self._closed:
            await self._connect()

            while True:
                chunk = await self._connection.read()
                if chunk is None:
                    _LOGGER.debug("Received None data from connection.read()")
                    break
                self._last_recv = datetime.datetime.now()
                self._dispatch(chunk)

    def _dispatch(self, chunk) -> None:
        """Decode one received chunk and hand the resulting event to the listeners."""
        try:
            text = chunk.decode('utf-8').strip()
        except UnicodeDecodeError:
            _LOGGER.warning("Failed to decode data", exc_info=True)
            return

        _LOGGER.debug("Decoding data: '%s'", text)
        if not text:
            return

        try:
            event = BaseEvent.decode(Packet.decode(text))
        except Exception:
            _LOGGER.warning("Failed to decode packet", exc_info=True)
            return

        if self._on_event_received is not None:
            self._on_event_received(event)
        self.alarm.handle_event(event)

    def _should_reconnect(self) -> bool:
        """True when nothing was received for ``update_interval + 30`` seconds."""
        if self._last_recv is None:
            return False
        stale_after = datetime.timedelta(seconds=self._update_interval + 30)
        return self._last_recv < datetime.datetime.now() - stale_after

    async def _update_loop(self) -> None:
        """Schedule a state update to keep the connection alive"""
        while True:
            await asyncio.sleep(self._update_interval)
            if self._closed:
                break
            await self.update()

    async def keepalive(self) -> None:
        """Run the receive loop and the periodic update loop concurrently."""
        await asyncio.gather(self._recv_loop(), self._update_loop())

    async def close(self) -> None:
        self._closed = True
        await self._connection.close()

    def on_state_change(self, f: Callable[[ArmingState], None]
                        ) -> Callable[[ArmingState], None]:
        """Decorator: register ``f`` with the alarm for arming-state changes."""
        self.alarm.on_state_change(f)
        return f

    def on_zone_change(self, f: Callable[[int, bool], None]
                       ) -> Callable[[int, bool], None]:
        """Decorator: register ``f`` with the alarm for zone changes."""
        self.alarm.on_zone_change(f)
        return f

    def on_event_received(self, f: Callable[[BaseEvent], None]
                          ) -> Callable[[BaseEvent], None]:
        """Decorator: set ``f`` as the single callback for every decoded event."""
        self._on_event_received = f
        return f