def test_first():
    """Exercise util.first on lists, predicates and the default argument."""
    items = [1, 2, 3]
    assert util.first(items) == 1
    assert util.first([]) is None
    assert util.first(items, lambda n: n > 1) == 2
    assert util.first(items, lambda n: n > 3) is None
    assert util.first([], default=4) == 4
def create():
    """Download the Windows image archive and convert its disk to qcow2.

    Reads module-level ``zip_path``, ``win10_url`` and ``img_path``; shells
    out to ``wget``, ``unzip``, ``tar`` and ``qemu-img`` (all must be on
    PATH). Raises ``subprocess.CalledProcessError`` on any tool failure.
    """
    zip_path.parent.mkdir(parents=True, exist_ok=True)
    # -c resumes a partial download; -O pins the destination file name
    subprocess.run(['wget', '--show-progress', '-q', '-c',
                    '-O', str(zip_path), str(win10_url)],
                   check=True)
    with _create_tmpdir() as tmpdir:
        subprocess.run(['unzip', '-q', str(zip_path.resolve())],
                       cwd=str(tmpdir), check=True)
        # the zip contains a single OVA, which itself is a tar archive
        ova_path = util.first(tmpdir.glob('*.ova'))
        subprocess.run(['tar', '-x', '-f', str(ova_path.resolve())],
                       cwd=str(tmpdir), check=True)
        vmdk_path = util.first(tmpdir.glob('*.vmdk'))
        subprocess.run(['qemu-img', 'convert', '-f', 'vmdk', '-O', 'qcow2',
                        str(vmdk_path), str(img_path)],
                       check=True)
def wait_active_server(servers, timeout):
    """Poll until one of `servers` reports active; return it, or None on timeout."""
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        active = util.first(servers, lambda s: s.is_active())
        if active:
            return active
        # short sleep keeps the poll loop from spinning the CPU
        time.sleep(0.01)
    return None
def _ext_get_view(name, view_path, conf_path):
    """Load a view directory into a View, validating conf against a bundled schema.

    Text assets are read as strings, json/yaml files are decoded, xml/svg are
    parsed, and everything else is base64 encoded. File keys are paths
    relative to `view_path` in posix form.
    """
    data = {}
    for path in view_path.rglob('*'):
        if path.is_dir():
            continue
        if path.suffix in {'.js', '.css', '.txt'}:
            with open(path, encoding='utf-8') as f:
                content = f.read()
        elif path.suffix in {'.json', '.yaml', '.yml'}:
            content = json.decode_file(path)
        elif path.suffix in {'.xml', '.svg'}:
            with open(path, encoding='utf-8') as f:
                content = vt.parse(f)
        else:
            # unknown payload - ship it base64 encoded
            with open(path, 'rb') as f:
                content = base64.b64encode(f.read()).decode('utf-8')
        data[path.relative_to(view_path).as_posix()] = content
    conf = json.decode_file(conf_path) if conf_path else None
    # a schema file bundled with the view validates the configuration
    schema = util.first(v for k, v in data.items()
                        if k in {'schema.json', 'schema.yaml', 'schema.yml'})
    if schema:
        repo = json.SchemaRepository(schema)
        repo.validate(schema['id'], conf)
    return View(name=name, conf=conf, data=data)
async def test_backend_to_frontend(backend, web_server, client, create_msg):
    """Entries registered on the backend must propagate to the client state."""
    client_change_queue = aio.Queue()
    client.register_change_cb(lambda: client_change_queue.put_nowait(None))
    entry_queue = aio.Queue()
    backend.register_change_cb(entry_queue.put_nowait)
    await client_change_queue.get()
    assert_client_vs_server_state(client)
    entries = []
    for _ in range(10):
        msg = create_msg()
        await backend.register(ts_now(), msg)
        registered = await entry_queue.get()
        entry = registered[0]
        # server state keeps newest entries first
        entries.insert(0, entry)
        await client_change_queue.get()
        server_entries = client.server_state['entries']
        assert entry in server_entries
        assert util.first(server_entries, lambda e: e.msg == msg)
        assert entries == server_entries
    assert client.server_state['first_id'] == 1
    assert client.server_state['last_id'] == len(entries)
    assert_client_vs_server_state(client)
def sanitize(t):
    """Recursively resolve implicit tagging and module references in type `t`.

    Closure over the enclosing scope's ``types``, ``imports``, ``name`` and
    ``default_implicit``.
    """
    if isinstance(t, common.PrefixedType):
        implicit = t.implicit if t.implicit is not None else default_implicit
        return t._replace(type=sanitize(t.type), implicit=implicit)
    if isinstance(t, (common.SetOfType, common.SequenceOfType)):
        return t._replace(type=sanitize(t.type))
    if isinstance(t, (common.SetType, common.SequenceType)):
        elements = [el._replace(type=sanitize(el.type))
                    for el in t.elements]
        return t._replace(elements=elements)
    if isinstance(t, common.ChoiceType):
        choices = [ch._replace(type=sanitize(ch.type))
                   for ch in t.choices]
        return t._replace(choices=choices)
    if isinstance(t, common.TypeRef):
        # ABSTRACT-SYNTAX.&Type stands for the open/any type
        if t == common.TypeRef('ABSTRACT-SYNTAX', 'Type'):
            return common.EntityType()
        if t.module:
            return t
        if t not in types:
            # unresolved local name - maybe it comes from an import
            imported = util.first(imports, lambda i: i.name == t.name)
            if imported:
                return t._replace(module=imported.module)
        return t._replace(module=name)
    return t
async def test_event_register(event_client_factory, adapter_factory):
    """Events registered through an adapter must reach the event client."""
    client = await event_client_factory([['*']])
    adapters_conf = [{'name': 'adapter1',
                      'module': 'test_unit.test_gui.mock'}]
    async with adapter_factory(adapters_conf) as adapters:
        adapter = adapters['adapter1']

        def make_event(event_type, source_timestamp):
            # all test events carry the same JSON payload
            return hat.event.common.RegisterEvent(
                event_type=event_type,
                source_timestamp=source_timestamp,
                payload=hat.event.common.EventPayload(
                    hat.event.common.EventPayloadType.JSON,
                    data={'abc': 'def'}))

        register_events = [
            make_event(['hat', 'gui', 'mock', 'system'],
                       hat.event.common.now()),
            make_event(['hat', 'gui', 'mock'], None),
            make_event(['should', 'register'], None)]
        adapter.client.register(register_events)
        events = await client.receive()
        assert len(register_events) == len(events)
        for ev in events:
            # each received event must correspond to one registered event
            assert util.first(
                register_events,
                lambda reg: (ev.event_type == reg.event_type
                             and ev.source_timestamp == reg.source_timestamp
                             and ev.payload == reg.payload))
async def test_ui_incorrect_cid_mid(cluster_factory):
    """Server must survive a set_rank request carrying an unknown cid."""
    cluster = await cluster_factory({'group': {'components': ['c1', 'c2']}})
    ui_client = await hat.juggler.connect(
        f'ws://127.0.0.1:{cluster.server_info.ui_port}/ws')
    await asyncio.sleep(0.5)
    try:
        await ui_client.send({'type': 'set_rank',
                              'payload': {'cid': 150,
                                          'rank': 2}})
        server_info = cluster.server_info
        assert common.process_is_running(server_info.process)
        connections = server_info.process.connections()
        # all three server ports must still be listening
        for port in {server_info.ui_port,
                     server_info.monitor_port,
                     server_info.master_port}:
            assert util.first(
                connections,
                lambda c: (c.laddr.ip == '0.0.0.0'
                           and c.laddr.port == port))
    finally:
        await ui_client.async_close()
def test_master_slave(monitor_factory):
    """Master monitor must hold a TCP connection to the slave's master port."""
    slave = monitor_factory()
    master = monitor_factory(parent_infos=[slave])
    connection = util.first(
        master.process.connections(),
        lambda conn: (conn.raddr.port == slave.master_port
                      and conn.raddr.ip == '127.0.0.1'
                      if conn.raddr else False))
    assert connection
def main():
    """Main

    Loads builtin and user-supplied translator modules, then either lists
    the registered translators or translates stdin, according to the parsed
    command line arguments. Output is printed to stdout in the requested
    format.

    Raises:
        Exception: when action is 'translate' and no translator matches
        NotImplementedError: for an unknown action

    """
    args = _create_parser().parse_args()
    json_schema_repo = json.SchemaRepository(
        json.json_schema_repo, *args.additional_json_schemas_paths)
    translators = []
    for module_name in itertools.chain(builtin_translators,
                                       args.module_names):
        translators += importlib.import_module(module_name).translators
    # renamed from `format` so the builtin is not shadowed
    fmt = {'yaml': json.Format.YAML,
           'json': json.Format.JSON}[args.format]
    if args.action == 'list':
        output = [_translator_to_json(trans) for trans in translators]
    elif args.action == 'translate':
        # later registered translators take precedence - search newest first
        # (reversed() avoids copying the list as translators[::-1] did)
        trans = util.first(
            reversed(translators),
            lambda t: (t.input_type == args.input_type
                       and t.output_type == args.output_type))
        if not trans:
            raise Exception('translator not found')
        input_conf = json.decode(sys.stdin.read(), format=fmt)
        if trans.input_schema:
            json_schema_repo.validate(trans.input_schema, input_conf)
        output = trans.translate(input_conf)
        if trans.output_schema:
            json_schema_repo.validate(trans.output_schema, output)
    else:
        raise NotImplementedError()
    print(json.encode(output, format=fmt, indent=4))
def console_debug_cb(result: MatchResult, call_stack: MatchCallStack):
    """Simple console debugger.

    Prints whether the match succeeded ('+++'/'---'), the rule call stack,
    the input consumed so far and the remaining (unconsumed) input.
    """
    success = '+++' if result.node else '---'
    stack = ', '.join(frame.name for frame in call_stack)
    data = util.first(call_stack).data
    # BUG FIX: data[:-len(result.rest)] evaluated to data[:0] == '' whenever
    # result.rest was empty (i.e. the whole input was consumed); compute the
    # end index explicitly so a fully-consumed input prints completely.
    consumed = data[:len(data) - len(result.rest)]
    print(success, stack)
    print('<<<', consumed)
    print('>>>', result.rest, flush=True)
def test_list(run_translator):
    """'list' action output must describe every registered translator."""
    result = run_translator('list', [])
    for trans in translators:
        assert util.first(
            result,
            lambda entry: (entry['input_type'] == trans.input_type
                           and entry['input_schema'] == trans.input_schema
                           and entry['output_type'] == trans.output_type
                           and entry['output_schema'] == trans.output_schema))
def filter_events_by_subscriptions(events, subscriptions):
    """Return the events whose type matches at least one subscription.

    Args:
        events: iterable of events with an ``event_type`` attribute
        subscriptions: iterable of query types accepted by
            ``hat.event.common.matches_query_type``

    Returns:
        list of matching events, preserving the input order

    """
    # any() short-circuits on the first match, like util.first did, but it
    # cannot mistakenly drop an event whose matching query type is falsy
    return [event for event in events
            if any(hat.event.common.matches_query_type(event.event_type,
                                                       q_type)
                   for q_type in subscriptions)]
async def test_archive(create_backend, create_msg, timestamp, db_path,
                       short_register_delay, enable_archive):
    """Exceeding high_size trims the db to low_size, archiving when enabled."""
    low_size = 50
    high_size = 100
    change_queue = aio.Queue()
    backend = await create_backend(low_size=low_size,
                                   high_size=high_size,
                                   enable_archive=enable_archive)
    backend.register_change_cb(change_queue.put_nowait)
    entries = []
    for _ in range(high_size):
        await backend.register(timestamp, create_msg())
        entries = await change_queue.get() + entries
    # wait for possible background db cleanup
    await asyncio.sleep(0.1)
    assert backend.last_id == high_size
    assert backend.first_id == 1
    result = await backend.query(common.Filter())
    assert len(result) == high_size
    # exactly at high_size: no db rotation has happened yet
    assert len(list(db_path.parent.glob(f'{db_path.name}.*'))) == 0
    await backend.register(timestamp, create_msg())
    entries = await change_queue.get() + entries
    # wait for expected background db cleanup
    await asyncio.sleep(0.1)
    assert backend.first_id == backend.last_id - low_size + 1
    assert backend.last_id == high_size + 1
    result = await backend.query(common.Filter())
    assert len(result) == low_size
    archive_count = len(list(db_path.parent.glob(f'{db_path.name}.*')))
    assert archive_count == (1 if enable_archive else 0)
    await backend.async_close()
    assert backend.is_closed
    if enable_archive:
        # reopen the rotated-out archive and check trimmed entries landed there
        archive_path = util.first(db_path.parent.glob('*.*'),
                                  lambda p: p.name == f'{db_path.name}.1')
        backend = await create_backend(path=archive_path,
                                       high_size=high_size,
                                       low_size=low_size)
        assert not backend.is_closed
        entries_archived = await backend.query(common.Filter())
        assert len(entries_archived) == (high_size - low_size + 1)
        assert result + entries_archived == entries
        await backend.async_close()
        assert backend.is_closed
def test_server_listens(monitor_factory):
    """Monitor process must listen on its ui, monitor and master ports."""
    server_info = monitor_factory()
    connections = server_info.process.connections()
    for port in {server_info.ui_port,
                 server_info.monitor_port,
                 server_info.master_port}:
        assert util.first(
            connections,
            lambda c: c.laddr.ip == '0.0.0.0' and c.laddr.port == port)
def _on_command(self, conn, cmd):
    """Handle a received command; return True when configured to succeed."""
    self._logger.log(f'received command {cmd}')
    key = _value_to_type(cmd.value), cmd.asdu_address, cmd.io_address
    # locate the configured command whose (type, asdu, io) matches the key
    command_id, command = util.first(
        self._data['commands'].items(),
        lambda item: (item[1]['type'], item[1]['asdu'], item[1]['io']) == key,
        (None, None))
    success = bool(command['success']) if command else False
    if success:
        self._data.set(['commands', command_id, 'value'],
                       _value_to_json(cmd.value))
    self._logger.log(f'sending command success {success}')
    return success
async def _event_loop(self):
    """Forward received events to all sessions until a kill event arrives."""
    try:
        while True:
            events = await self._event_client.receive()
            # a ['a1', 'kill'] event terminates the loop
            if util.first(events,
                          lambda ev: ev.event_type == ['a1', 'kill']):
                break
            for session in self._sessions:
                session._event_queue.put_nowait(events)
    finally:
        self._group.close()
def set_rank(self, cid: int, rank: int):
    """Set component rank"""
    info = util.first(self._local_components, lambda c: c.cid == cid)
    if not info or info.rank == rank:
        # unknown component or no-op change
        return
    updated = info._replace(rank=rank)
    self._local_components = [updated if c is info else c
                              for c in self._local_components]
    if info.name is not None:
        # remember the rank so reconnecting named components keep it
        self._rank_cache[info.name, info.group] = rank
    self._change_cbs.notify()
def test_first_example():
    """Documentation-style examples for util.first."""
    assert util.first(range(3)) == 0
    assert util.first(range(3), lambda i: i > 1) == 2
    assert util.first(range(3), lambda i: i > 2) is None
    assert util.first(range(3), lambda i: i > 2, 123) == 123
    # iterating a dict yields its keys
    assert util.first({1: 'a', 2: 'b', 3: 'c'}) == 1
    assert util.first([], default=123) == 123
def _update_global_on_local_components(self, local_components, mid):
    """Merge monitor `mid`'s local component infos into the global list."""
    # drop this monitor's previous entries, keep everyone else's
    merged = [c for c in self._global_components if c.mid != mid]
    for component in local_components:
        previous = util.first(
            self._global_components,
            lambda c: c.cid == component.cid and c.mid == component.mid)
        if previous:
            # preserve the blessing already granted to this component
            merged.append(component._replace(blessing=previous.blessing))
        else:
            merged.append(component)
    self._calculate_global_components(merged)
def _set_components(self, msg_server):
    """Update the component list and own info from a MsgServer message."""
    if (msg_server.data.module != 'HatMonitor'
            or msg_server.data.type != 'MsgServer'):
        raise Exception('Message received from server malformed: message '
                        'MsgServer from HatMonitor module expected')
    data = msg_server.data.data
    self._components = [common.component_info_from_sbs(i)
                        for i in data['components']]
    # own info is identified by the (cid, mid) pair carried by the message
    self._info = util.first(
        self._components,
        lambda i: i.cid == data['cid'] and i.mid == data['mid'])
    self._change_cbs.notify()
def _decode_external(content):
    """Decode an ASN.1 External value from parsed constructed content.

    Raises:
        ValueError: when the encoding element's tag is not 0, 1 or 2

    """
    # optional direct reference: UNIVERSAL OBJECT IDENTIFIER (tag 6)
    oid_element = util.first(
        content.elements,
        lambda el: (el.class_type == common.ClassType.UNIVERSAL
                    and el.tag_number == 6))
    direct_ref = (_decode_objectidentifier(oid_element.content)
                  if oid_element else None)
    # optional indirect reference: UNIVERSAL INTEGER (tag 2)
    int_element = util.first(
        content.elements,
        lambda el: (el.class_type == common.ClassType.UNIVERSAL
                    and el.tag_number == 2))
    indirect_ref = (_decode_integer(int_element.content)
                    if int_element else None)
    encoding = content.elements[-1]
    if encoding.tag_number == 0:
        # single-ASN1-type: a nested value
        data = encoding.content.elements[0]
    elif encoding.tag_number == 1:
        # octet-aligned
        data = _decode_octetstring(encoding.content)
    elif encoding.tag_number == 2:
        # arbitrary bit string
        data = _decode_bitstring(encoding.content)
    else:
        raise ValueError('invalid external content')
    return common.External(data=data,
                           direct_ref=direct_ref,
                           indirect_ref=indirect_ref)
def _bless_one(components):
    """Grant the blessing token to exactly one lowest-rank component.

    Returns a mapping from (cid, mid) to the token for the blessed component
    and None for all others.
    """
    global _last_token_id
    lowest_rank = min(c.rank for c in components)
    candidates = [c for c in components if c.rank == lowest_rank]
    # if one candidate already holds a blessing, it keeps it
    blessed = util.first(candidates, lambda c: c.blessing is not None)
    if blessed:
        return {(c.cid, c.mid): blessed.blessing if c is blessed else None
                for c in components}
    # while any component is still ready, withhold the token entirely
    if any(c.ready is not None for c in components):
        return {(c.cid, c.mid): None for c in components}
    chosen = min(candidates, key=lambda c: c.mid)
    _last_token_id += 1
    return {(c.cid, c.mid): _last_token_id if c is chosen else None
            for c in components}
async def test_ui_malformed_message(monitor_factory):
    """Monitor must drop a malformed ui client but keep serving its ports."""
    monitor = monitor_factory()
    ui_client = await hat.juggler.connect(
        f'ws://127.0.0.1:{monitor.ui_port}/ws')
    await asyncio.sleep(0.5)
    await ui_client.send('JSON serializable data')
    # server is expected to close the offending connection...
    await ui_client.wait_closed()
    # ...while staying alive and listening on all of its ports
    assert common.process_is_running(monitor.process)
    connections = monitor.process.connections()
    for port in {monitor.ui_port, monitor.monitor_port, monitor.master_port}:
        assert util.first(
            connections,
            lambda c: c.laddr.ip == '0.0.0.0' and c.laddr.port == port)
async def _address_loop(monitor_client, server_group, address_queue):
    """Track the blessed-and-ready server in `server_group`.

    Pushes every address change (including None when no server qualifies)
    onto `address_queue`, waking up on monitor client change notifications.
    """
    last_address = None
    changes = aio.Queue()
    with monitor_client.register_change_cb(lambda: changes.put_nowait(None)):
        while True:
            # the active server is the one whose blessing equals its ready token
            active = util.first(
                monitor_client.components,
                lambda c: (c.group == server_group
                           and c.blessing is not None
                           and c.blessing == c.ready))
            address = active.address if active else None
            if address != last_address and not address_queue.is_closed:
                mlog.debug("new server address: %s", address)
                last_address = address
                address_queue.put_nowait(address)
            await changes.get()
def _set_client(self, cid, name, group, address, ready):
    """Refresh a local component's client-provided attributes."""
    # NOTE(review): assumes a component with this cid already exists in
    # self._local_components - confirm callers guarantee this
    info = util.first(self._local_components, lambda c: c.cid == cid)
    updated = info._replace(name=name,
                            group=group,
                            address=address,
                            ready=ready)
    if info.name is None:
        # component identifies itself for the first time: restore cached rank
        rank = self._rank_cache.get((name, group), info.rank)
        updated = updated._replace(rank=rank)
    if info == updated:
        return
    self._local_components = [updated if c is info else c
                              for c in self._local_components]
    self._change_cbs.notify()
async def _idle_loop(self, conf, device_module, enabled):
    """Drive the device lifecycle from enable events.

    Waits on new events and on device closure simultaneously; creates the
    device when enabled, destroys it when disabled, and forwards non-enable
    events to the running device's event client. Runs until the device
    closes on its own or the loop is cancelled.
    """
    try:
        while True:
            # race the next batch of events against the device closing;
            # with no device, an unresolved Future stands in for "never"
            new_events_future = self._async_group.spawn(
                self._event_queue.get)
            device_closed_future = (self._device.closed
                                    if self._device is not None
                                    else asyncio.Future())
            await asyncio.wait([new_events_future, device_closed_future],
                               return_when=asyncio.FIRST_COMPLETED)
            if device_closed_future.done():
                # device closed on its own - leave the idle loop entirely
                return
            events = new_events_future.result()
            enable_event_type = [*self._device_identifier_prefix,
                                 'system', 'enable']
            # split the batch into enable-control events and device events
            enable_events = [ev for ev in events
                             if _check_bool_event(ev, enable_event_type)]
            device_events = [ev for ev in events
                             if ev.event_type != enable_event_type]
            if len(set([ev.payload.data for ev in enable_events])) > 1:
                # contradictory enable values in one batch: ignore them all
                mlog.warning(
                    'multiple distinct enable values set for device %s, '
                    'retaining previous enable value (%s)',
                    conf['name'], enabled)
            else:
                enable_event = util.first(enable_events)
                enabled = (enable_event.payload.data
                           if enable_event else enabled)
            if enabled and self._device is None:
                await self._create_device(conf, device_module)
            elif not enabled and self._device is not None:
                await self._destroy_device()
            elif enabled and self._device is not None:
                self._device_event_client.add_events(device_events)
    except asyncio.CancelledError:
        mlog.debug('device idle loop cancelled')
    finally:
        self._async_group.close()
async def create_device_proxy(conf, client, gateway_name):
    """Create device proxy

    Args:
        conf (hat.json.Data): configuration defined by
            ``hat://gateway/main.yaml#/definitions/device``
        client (hat.event.client.Client): event client
        gateway_name (str): gateway name

    Returns:
        DeviceProxy

    """
    proxy = DeviceProxy()
    device_module = importlib.import_module(conf['module'])
    proxy._event_queue = aio.Queue()
    proxy._async_group = aio.Group(proxy._on_exception)
    proxy._client = client
    proxy._device_identifier_prefix = ['gateway',
                                       gateway_name,
                                       device_module.device_type,
                                       conf['name']]
    # query the event server for the last persisted enable state
    enable_event_type = [*proxy._device_identifier_prefix, 'system', 'enable']
    enable_events = await client.query(
        hat.event.common.QueryData(event_types=[enable_event_type],
                                   unique_type=True))
    enable_event = util.first(
        enable_events, lambda ev: _check_bool_event(ev, enable_event_type))
    enabled = enable_event.payload.data if enable_event is not None else False
    proxy._register_device_running_event(False)
    if enabled:
        await proxy._create_device(conf, device_module)
    else:
        proxy._device = None
        proxy._device_event_client = None
    proxy._async_group.spawn(proxy._idle_loop, conf, device_module, enabled)
    proxy._async_group.spawn(aio.call_on_cancel, proxy._cleanup)
    return proxy
def set_rank(self, cid, mid, rank):
    """Set component's rank

    Args:
        cid (int): component id
        mid (int): component's local monitor id
        rank (int): component's rank

    """
    if self._master:
        self._master.set_rank(cid, mid, rank)
    if mid != self.mid:
        # not a local component - nothing to change locally
        return
    self._change_local_component(cid, rank=rank)
    if not self._master:
        # without a master, propagate the change ourselves
        self._handle_local_changes()
    info = util.first(self._local_components, lambda c: c.cid == cid)
    if info is None:
        return
    # cache the rank so the component keeps it after reconnecting
    self._rank_cache[info.name, info.group] = info.rank
async def test_ui_incorrect_type(cluster_factory):
    """Server must survive a ui message with an unknown type."""
    cluster = await cluster_factory({'group': {'components': ['c1', 'c2']}})
    ui_client = await hat.juggler.connect(
        f'ws://localhost:{cluster.server_info.ui_port}/ws')
    await asyncio.sleep(0.5)
    await ui_client.send({'type': 'incorrect_type', 'payload': None})
    # server closes the offending connection...
    await ui_client.closed
    # ...while staying alive and listening on all of its ports
    server_info = cluster.server_info
    assert common.process_is_running(server_info.process)
    connections = server_info.process.connections()
    for port in {server_info.ui_port,
                 server_info.monitor_port,
                 server_info.master_port}:
        assert util.first(
            connections,
            lambda c: c.laddr.ip == '0.0.0.0' and c.laddr.port == port)