def expand_shards(
    self,
    shards: Optional[Collection[ShardID]] = None,
    node_ids: Optional[Collection[NodeID]] = None,
) -> Tuple[ShardID, ...]:
    shards = list(shards or [])
    node_ids = list(node_ids or [])

    for node_id in node_ids:
        shards.append(ShardID(node=node_id, shard_index=ALL_SHARDS))

    ret: Set[ShardID] = set()
    for shard in shards:
        node_view = self.get_node_view(
            node_index=shard.node.node_index, node_name=shard.node.name
        )
        if not node_view.is_storage:
            continue

        if shard.shard_index == ALL_SHARDS:
            r = range(0, node_view.num_shards)
        else:
            r = range(shard.shard_index, shard.shard_index + 1)

        for shard_index in r:
            ret.add(ShardID(node=node_view.node_id, shard_index=shard_index))

    return tuple(
        sorted(ret, key=lambda shard: (shard.node.node_index, shard.shard_index))
    )
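# A minimal usage sketch for expand_shards (not part of the original module).
# It assumes an already-built ClusterView `cv`; the node index 7 is made up
# for illustration, and the helper name below is hypothetical.
def _example_expand_shards(cv: "ClusterView") -> None:
    # Passing a node_id is shorthand for ShardID(node, ALL_SHARDS): both calls
    # should expand to one ShardID per shard on node N7 (non-storage nodes
    # would expand to nothing).
    by_shard = cv.expand_shards(
        shards=[ShardID(node=NodeID(node_index=7), shard_index=ALL_SHARDS)]
    )
    by_node = cv.expand_shards(node_ids=[NodeID(node_index=7)])
    assert by_shard == by_node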
async def test_get_shard_last_updated_at(self):
    async with MockAdminAPI() as client:
        cv = await get_cluster_view(client)
        shard = ShardID(
            node=cv.get_node_view_by_node_index(0).node_id, shard_index=1
        )
        await apply_maintenance(
            client=client,
            shards=[shard],
            shard_target_state=ShardOperationalState.DRAINED,
        )
        cv = await get_cluster_view(client)
        mv = list(cv.get_all_maintenance_views())[0]
        self.assertIsNone(mv.get_shard_last_updated_at(shard))

        ts = datetime.now()
        client._set_shard_maintenance_progress(
            shard,
            ShardMaintenanceProgress(
                status=MaintenanceStatus.STARTED,
                target_states=[ShardOperationalState.MAY_DISAPPEAR],
                created_at=ts,
                last_updated_at=datetime.now(),
                associated_group_ids=["johnsnow"],
            ).to_thrift(),
        )
        cv = await get_cluster_view(client)
        mv = list(cv.get_all_maintenance_views())[0]
        self.assertTrue(
            mv.get_shard_last_updated_at(shard) - ts < timedelta(seconds=2)
        )
async def check_impact(
    client: AdminAPI,
    nodes: Optional[Collection[Node]] = None,
    shards: Optional[Collection[ShardID]] = None,
    target_storage_state: ShardStorageState = ShardStorageState.DISABLED,
    disable_sequencers: bool = True,
) -> CheckImpactResponse:
    """
    Performs a safety check and returns a CheckImpactResponse. If neither
    nodes nor shards are passed, it still performs the safety check and
    returns the current state of the cluster.
    """
    nodes = nodes or []
    shards = shards or []

    req_shards: FrozenSet[ShardID] = _recombine_shards(
        list(shards)  # shards is a generic Collection, not a List
        + [
            ShardID(
                node=NodeID(node_index=n.node_index, address=n.data_addr.to_thrift()),
                shard_index=-1,
            )
            for n in nodes
        ]
    )

    return await admin_api.check_impact(
        client=client,
        req=CheckImpactRequest(
            shards=req_shards,
            target_storage_state=target_storage_state,
            disable_sequencers=disable_sequencers,
        ),
    )
def parse_shards(src: Collection[str]) -> Set[ShardID]:
    """
    Parses a list of strings and interprets them as ShardID objects.

    Accepted examples:
        0      => ShardID(0, -1)
        N0     => ShardID(0, -1)
        0:2    => ShardID(0, 2)
        N0:2   => ShardID(0, 2)
        N0:S2  => ShardID(0, 2)
    """
    regex = re.compile(r"^N?(\d+)\:?S?(\d+)?$", flags=re.IGNORECASE)
    res = set()
    for s in src:
        match = regex.search(s)
        if not match:
            raise ValueError(f"Cannot parse shard: {s}")
        node_index = int(match.groups()[0])
        if match.groups()[1] is None:
            shard_index = ALL_SHARDS
        else:
            shard_index = int(match.groups()[1])
        res.add(ShardID(node=NodeID(node_index=node_index), shard_index=shard_index))
    return res
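# A minimal sketch of parse_shards behaviour (illustration only; it mirrors
# the docstring above and the unit tests below). The helper name is
# hypothetical and not part of the original module.
def _example_parse_shards() -> None:
    parsed = parse_shards(["N0", "N1:S2", "n1:s2"])
    # "N0" has no shard part, so it expands to ALL_SHARDS; "n1:s2" is a
    # case-insensitive duplicate of "N1:S2" and collapses because the result
    # is a set.
    assert parsed == {
        ShardID(node=NodeID(node_index=0), shard_index=ALL_SHARDS),
        ShardID(node=NodeID(node_index=1), shard_index=2),
    }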
async def test_smoke(self):
    async with MockAdminAPI() as client:
        cv = await get_cluster_view(client)
        await apply_maintenance(
            client=client,
            shards=[
                ShardID(
                    node=cv.get_node_view_by_node_index(0).node_id, shard_index=1
                )
            ],
            sequencer_nodes=[cv.get_node_view_by_node_index(0).node_id],
        )
        await apply_maintenance(
            client=client,
            node_ids=[cv.get_node_id(node_index=1)],
            user="******",
            reason="whatever",
        )
        (cv, nc_resp, ns_resp, mnts_resp) = await asyncio.gather(
            get_cluster_view(client),
            client.getNodesConfig(NodesFilter()),
            client.getNodesState(NodesStateRequest()),
            client.getMaintenances(MaintenancesFilter()),
        )
        self._validate(cv, nc_resp.nodes, ns_resp.states, tuple(mnts_resp.maintenances))
async def check_impact(
    client: AdminAPI,
    node_ids: Optional[Collection[NodeID]] = None,
    shards: Optional[Collection[ShardID]] = None,
    target_storage_state: ShardStorageState = ShardStorageState.DISABLED,
    disable_sequencers: bool = True,
) -> CheckImpactResponse:
    """
    Performs a safety check and returns a CheckImpactResponse. If neither
    node_ids nor shards are passed, it still performs the safety check and
    returns the current state of the cluster.
    """
    node_ids = set(node_ids or [])
    shards = set(shards or [])

    req_shards: Set[ShardID] = _recombine_shards(
        shards.union(ShardID(node=n, shard_index=ALL_SHARDS) for n in node_ids)
    )

    return await admin_api.check_impact(
        client=client,
        req=CheckImpactRequest(
            shards=list(req_shards),
            target_storage_state=target_storage_state,
            disable_sequencers=disable_sequencers,
        ),
    )
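# A minimal usage sketch for check_impact (not from the original source). The
# node index is made up, and the already-connected Admin API client is assumed
# to be provided by the caller; only the call shape follows the signature
# above: checking the impact of disabling every shard on N0 while leaving
# sequencers untouched.
async def _example_check_impact(client: AdminAPI) -> CheckImpactResponse:
    return await check_impact(
        client=client,
        node_ids=[NodeID(node_index=0)],
        target_storage_state=ShardStorageState.DISABLED,
        disable_sequencers=False,
    )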
async def apply_maintenance(
    client: AdminAPI,
    node_ids: Optional[Collection[NodeID]] = None,
    shards: Optional[Collection[ShardID]] = None,
    shard_target_state: Optional[
        ShardOperationalState
    ] = ShardOperationalState.MAY_DISAPPEAR,
    sequencer_nodes: Optional[Collection[NodeID]] = None,
    group: Optional[bool] = True,
    ttl: Optional[timedelta] = None,
    user: Optional[str] = None,
    reason: Optional[str] = None,
    extras: Optional[Mapping[str, str]] = None,
    skip_safety_checks: Optional[bool] = False,
    allow_passive_drains: Optional[bool] = False,
) -> Collection[MaintenanceDefinition]:
    """
    Applies a maintenance to MaintenanceManager. If the `node_ids` argument is
    specified, those nodes are treated as both shards and sequencers
    simultaneously. Can return multiple maintenances if group == False.
    """
    node_ids = set(node_ids or [])
    shards = set(shards or [])
    sequencer_nodes = set(sequencer_nodes or []).union(node_ids)

    if ttl is None:
        ttl = timedelta(seconds=0)
    if user is None:
        user = "******"
    if reason is None:
        reason = "Not Specified"
    if extras is None:
        extras = {}

    shards = shards.union({ShardID(node=n, shard_index=-1) for n in node_ids})
    shards = _recombine_shards(shards)

    req = MaintenanceDefinition(
        shards=list(shards),
        shard_target_state=shard_target_state,
        sequencer_nodes=[n for n in sequencer_nodes],
        sequencer_target_state=SequencingState.DISABLED,
        user=user,
        reason=reason,
        extras=extras,
        skip_safety_checks=skip_safety_checks,
        group=group,
        ttl_seconds=int(ttl.total_seconds()),
        allow_passive_drains=allow_passive_drains,
    )
    resp: MaintenanceDefinitionResponse = await admin_api.apply_maintenance(
        client=client, req=req
    )
    return resp.maintenances
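# A minimal usage sketch for apply_maintenance (illustration only, not from
# the original source). The node indexes, user, reason and TTL are made up;
# the call drains shard 1 of N0 and disables sequencing on N2 as one request.
async def _example_apply_maintenance(client: AdminAPI) -> None:
    maintenances = await apply_maintenance(
        client=client,
        shards=[ShardID(node=NodeID(node_index=0), shard_index=1)],
        shard_target_state=ShardOperationalState.DRAINED,
        sequencer_nodes=[NodeID(node_index=2)],
        user="ops-oncall",
        reason="disk replacement",
        ttl=timedelta(hours=2),
    )
    # With group=True (the default) the Admin API is expected to return a
    # single grouped MaintenanceDefinition.
    for mnt in maintenances:
        print(mnt.group_id)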
def test_parse_shards_valid2(self) -> None:
    # Parse multiple inputs
    self.assertEqual(
        {
            ShardID(node=NodeID(node_index=0), shard_index=1),
            ShardID(node=NodeID(node_index=1), shard_index=2),
        },
        helpers.parse_shards(["N0:S1", "N1:S2"]),
    )
    # Remove duplicates
    self.assertEqual(
        {
            ShardID(node=NodeID(node_index=0), shard_index=1),
            ShardID(node=NodeID(node_index=1), shard_index=2),
        },
        helpers.parse_shards(["N0:S1", "N1:S2", "N0:s1"]),
    )
async def applyMaintenance(
    self, request: MaintenanceDefinition
) -> MaintenanceDefinitionResponse:
    # TODO: ungroup if group == False
    shards = []
    for sh in request.shards:
        if sh.shard_index == -1:
            # TODO: make it unwrap
            pass
        else:
            assert sh.node.node_index is not None
            # pyre-fixme[6]: Expected `int` for 1st param but got `Optional[int]`.
            nc = self._nc_by_node_index[sh.node.node_index]
            shards.append(
                ShardID(
                    node=NodeID(
                        node_index=nc.node_index,
                        name=nc.name,
                        address=nc.data_address,
                    ),
                    shard_index=sh.shard_index,
                )
            )

    seq_nodes = []
    for n in request.sequencer_nodes:
        assert n.node_index is not None
        # pyre-fixme[6]: Expected `int` for 1st param but got `Optional[int]`.
        nc = self._nc_by_node_index[n.node_index]
        seq_nodes.append(
            NodeID(node_index=nc.node_index, name=nc.name, address=nc.data_address)
        )

    mnt = MaintenanceDefinition(
        shards=shards,
        shard_target_state=request.shard_target_state,
        sequencer_nodes=seq_nodes,
        sequencer_target_state=request.sequencer_target_state,
        user=request.user,
        reason=request.reason,
        extras=request.extras,
        skip_safety_checks=request.skip_safety_checks,
        force_restore_rebuilding=request.force_restore_rebuilding,
        group=request.group,
        ttl_seconds=request.ttl_seconds,
        allow_passive_drains=request.allow_passive_drains,
        group_id=gen_word(8),
        last_check_impact_result=None,
        expires_on=1000 * (int(datetime.now().timestamp()) + request.ttl_seconds)
        if request.ttl_seconds
        else None,
        created_on=1000 * int(datetime.now().timestamp()),
    )
    assert mnt.group_id is not None
    # pyre-fixme[6]: Expected `str` for 1st param but got `Optional[str]`.
    self._maintenances_by_id[mnt.group_id] = mnt
    return MaintenanceDefinitionResponse(maintenances=[mnt])
def _combine(
    cv: ClusterView,
    shards: Optional[List[str]] = None,
    node_names: Optional[List[str]] = None,
    node_indexes: Optional[List[int]] = None,
) -> Tuple[ShardID, ...]:
    shards = list(shards or [])
    node_names = list(node_names or [])
    node_indexes = list(node_indexes or [])

    shard_ids = parse_shards(shards)
    for nn in node_names:
        shard_ids.add(ShardID(node=cv.get_node_id(node_name=nn), shard_index=-1))
    for ni in node_indexes:
        shard_ids.add(ShardID(node=NodeID(node_index=ni), shard_index=-1))

    shard_ids_expanded = cv.expand_shards(shard_ids)
    return shard_ids_expanded
async def smoke(self, client) -> None:
    cv = await get_cluster_view(client)
    storages_node_views = [nv for nv in cv.get_all_node_views() if nv.is_storage]
    sequencers_node_views = [
        nv for nv in cv.get_all_node_views() if nv.is_sequencer
    ]

    # combined maintenance, storages and sequencers
    await apply_maintenance(
        client=client,
        shards=[ShardID(node=storages_node_views[0].node_id, shard_index=1)],
        sequencer_nodes=[sequencers_node_views[0].node_id],
    )

    # storage-only maintenance
    await apply_maintenance(
        client=client,
        shards=[ShardID(node=storages_node_views[1].node_id, shard_index=1)],
    )

    # sequencer-only maintenance
    await apply_maintenance(
        client=client, sequencer_nodes=[sequencers_node_views[2].node_id]
    )

    # maintenance for whole nodes
    await apply_maintenance(
        client=client,
        node_ids=[storages_node_views[3].node_id, sequencers_node_views[4].node_id],
        user="******",
        reason="whatever",
    )

    (cv, nc_resp, ns_resp, mnts_resp) = await asyncio.gather(
        get_cluster_view(client),
        client.getNodesConfig(NodesFilter()),
        client.getNodesState(NodesStateRequest()),
        client.getMaintenances(MaintenancesFilter()),
    )
    self._validate(cv, nc_resp.nodes, ns_resp.states, tuple(mnts_resp.maintenances))
async def test_smoke(self):
    ni = 0
    async with MockAdminAPI() as client:
        cv = await get_cluster_view(client)
        maintenances_resp = await apply_maintenance(
            client=client,
            shards=[
                ShardID(
                    node=cv.get_node_view_by_node_index(0).node_id, shard_index=1
                )
            ],
            sequencer_nodes=[cv.get_node_view_by_node_index(0).node_id],
        )
        (
            nodes_config_resp,
            nodes_state_resp,
            maintenances_resp,
        ) = await asyncio.gather(
            client.getNodesConfig(NodesFilter(node=NodeID(node_index=ni))),
            client.getNodesState(
                NodesStateRequest(filter=NodesFilter(node=NodeID(node_index=ni)))
            ),
            client.getMaintenances(MaintenancesFilter()),
        )

        nc = [n for n in nodes_config_resp.nodes if n.node_index == ni][0]
        ns = [n for n in nodes_state_resp.states if n.node_index == ni][0]
        mnt_ids = set()
        for mnt in maintenances_resp.maintenances:
            for s in mnt.shards:
                if s.node.node_index == ni:
                    mnt_ids.add(mnt.group_id)
            for n in mnt.sequencer_nodes:
                if n.node_index == ni:
                    mnt_ids.add(mnt.group_id)
        mnts = tuple(
            sorted(
                (
                    mnt
                    for mnt in maintenances_resp.maintenances
                    if mnt.group_id in mnt_ids
                ),
                key=operator.attrgetter("group_id"),
            )
        )
        nv = NodeView(node_config=nc, node_state=ns, maintenances=mnts)
        self._validate(nv, nc, ns, mnts)
async def test_shard_only(self):
    async with MockAdminAPI() as client:
        cv = await get_cluster_view(client)
        await apply_maintenance(
            client=client,
            shards=[
                ShardID(
                    node=cv.get_node_view_by_node_index(0).node_id, shard_index=1
                )
            ],
        )
        cv = await get_cluster_view(client)
        self.validate(
            maintenance_view=list(cv.get_all_maintenance_views())[0],
            maintenance=list(cv.get_all_maintenances())[0],
            node_index_to_node_view={0: cv.get_node_view(node_index=0)},
        )
async def test_node_is_not_a_sequencer(self):
    async with MockAdminAPI(disaggregated=True) as client:
        cv = await get_cluster_view(client)
        shard = ShardID(
            node=cv.get_node_view_by_node_index(0).node_id, shard_index=1
        )
        await apply_maintenance(
            client=client,
            shards=[shard],
            shard_target_state=ShardOperationalState.DRAINED,
        )
        cv = await get_cluster_view(client)
        mv = list(cv.get_all_maintenance_views())[0]
        node_id = cv.get_node_view_by_node_index(0).node_id
        with self.assertRaises(NodeIsNotASequencerError):
            mv.get_sequencer_maintenance_status(node_id)
        with self.assertRaises(NodeIsNotASequencerError):
            mv.get_sequencer_last_updated_at(node_id)
def to_shard_id(scope: str) -> ShardID:
    """
    A conversion utility that takes an "Nx:Sy" string and converts it into a
    typed ShardID. The "Sy" part is optional; if unset, the generated ShardID
    will have its shard_index set to -1.
    """
    scope = scope.upper()
    if not scope:
        raise ValueError("Cannot parse empty scope")

    match = SHARD_PATTERN.match(scope)
    if match is None:
        # There were no shards, or the format is invalid.
        raise ValueError(f"Cannot parse '{scope}'. Invalid format!")
    results = match.groupdict()
    shard_index = -1
    if results["shard_index"] is not None:
        shard_index = int(results["shard_index"])
    node_index = int(results["node_index"])
    node = NodeID(node_index=node_index)
    return ShardID(node=node, shard_index=shard_index)
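# A tiny illustration of to_shard_id (not part of the original module); the
# inputs mirror the docstring above and the helper name is hypothetical.
def _example_to_shard_id() -> None:
    assert to_shard_id("N1:S4") == ShardID(node=NodeID(node_index=1), shard_index=4)
    # Without the "Sy" part the shard_index defaults to -1 (all shards);
    # the scope is upper-cased first, so lower-case input is accepted.
    assert to_shard_id("n1") == ShardID(node=NodeID(node_index=1), shard_index=-1)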
def test_parse_shards_valid1(self) -> None:
    # 5
    self.assertEqual(
        {ShardID(node=NodeID(node_index=5), shard_index=ALL_SHARDS)},
        helpers.parse_shards(["5"]),
    )
    # 5:1
    self.assertEqual(
        {ShardID(node=NodeID(node_index=5), shard_index=1)},
        helpers.parse_shards(["5:1"]),
    )
    # 0:S1
    self.assertEqual(
        {ShardID(node=NodeID(node_index=0), shard_index=1)},
        helpers.parse_shards(["0:S1"]),
    )
    # N0:S1
    self.assertEqual(
        {ShardID(node=NodeID(node_index=0), shard_index=1)},
        helpers.parse_shards(["N0:S1"]),
    )
    # N0 == ShardID(0, ALL_SHARDS)
    self.assertEqual(
        {ShardID(node=NodeID(node_index=0), shard_index=ALL_SHARDS)},
        helpers.parse_shards(["N0"]),
    )
    # N1:S4 == ShardID(1, 4)
    self.assertEqual(
        {ShardID(node=NodeID(node_index=1), shard_index=4)},
        helpers.parse_shards(["N1:S4"]),
    )
    # Case is ignored: n1:S4 == ShardID(1, 4)
    self.assertEqual(
        {ShardID(node=NodeID(node_index=1), shard_index=4)},
        helpers.parse_shards(["n1:S4"]),
    )
async def apply(
    self,
    reason: str,
    node_indexes: Optional[List[int]] = None,
    node_names: Optional[List[str]] = None,
    shards: Optional[List[str]] = None,
    shard_target_state: Optional[str] = "may-disappear",
    sequencer_node_indexes: Optional[List[int]] = None,
    sequencer_node_names: Optional[List[str]] = None,
    user: Optional[str] = "",
    group: Optional[bool] = True,
    skip_safety_checks: Optional[bool] = False,
    ttl: Optional[int] = 0,
    allow_passive_drains: Optional[bool] = False,
    force_restore_rebuilding: Optional[bool] = False,
):
    """
    Applies a new maintenance to the Maintenance Manager.
    """
    ctx = context.get_context()
    try:
        async with ctx.get_cluster_admin_client() as client:
            cv = await get_cluster_view(client)

            all_node_indexes = set()
            if node_indexes is not None:
                all_node_indexes = all_node_indexes.union(set(node_indexes))
            if node_names is not None:
                all_node_indexes = all_node_indexes.union(
                    {
                        cv.get_node_index(node_name=node_name)
                        for node_name in set(node_names)
                    }
                )

            shard_ids = set()
            sequencer_nodes = set()
            for ni in all_node_indexes:
                nv = cv.get_node_view(node_index=ni)
                if nv.is_storage:
                    shard_ids.add(ShardID(node=nv.node_id, shard_index=-1))
                if nv.is_sequencer:
                    sequencer_nodes.add(nv.node_id)

            if sequencer_node_indexes is not None:
                for ni in set(sequencer_node_indexes):
                    nv = cv.get_node_view(node_index=ni)
                    if nv.is_sequencer:
                        sequencer_nodes.add(nv.node_id)
            if sequencer_node_names is not None:
                for nn in set(sequencer_node_names):
                    nv = cv.get_node_view(node_name=nn)
                    if nv.is_sequencer:
                        sequencer_nodes.add(nv.node_id)

            if shards is not None:
                shard_ids = shard_ids.union(cv.expand_shards(parse_shards(shards)))
    except NodeNotFoundError as e:
        print(colored(f"Node not found: {e}", "red"))
        return

    try:
        async with ctx.get_cluster_admin_client() as client:
            maintenances: Collection[MaintenanceDefinition]
            maintenances = await apply_maintenance(
                client=client,
                shards=shard_ids,
                shard_target_state=_parse_shard_target_state(shard_target_state),
                sequencer_nodes=list(sequencer_nodes),
                group=group,
                ttl=timedelta(seconds=ttl),
                user=user or getuser(),
                reason=reason,
                skip_safety_checks=skip_safety_checks,
                allow_passive_drains=allow_passive_drains,
                force_restore_rebuilding=force_restore_rebuilding,
            )
            cv = await get_cluster_view(client)
    except Exception as e:
        print(colored(f"Cannot apply maintenance: {e}", "red"))
        return

    print(
        _render(
            [
                cv.get_maintenance_view_by_id(id)
                for id in [mnt.group_id for mnt in maintenances]
            ],
            cv,
            mode=RenderingMode.EXPANDED,
        )
    )
def _validate(
    self,
    cv: ClusterView,
    ncs: List[NodeConfig],
    nss: List[NodeState],
    mnts: Tuple[MaintenanceDefinition, ...],
):
    nis = sorted(nc.node_index for nc in ncs)
    ni_to_nc = {nc.node_index: nc for nc in ncs}
    ni_to_ns = {ns.node_index: ns for ns in nss}
    ni_to_mnts: Dict[int, List[MaintenanceDefinition]] = {ni: [] for ni in nis}
    for mnt in mnts:
        mnt_nis = set()
        for s in mnt.shards:
            assert s.node.node_index is not None
            mnt_nis.add(s.node.node_index)
        for n in mnt.sequencer_nodes:
            assert n.node_index is not None
            mnt_nis.add(n.node_index)
        for ni in mnt_nis:
            ni_to_mnts[ni].append(mnt)

    self.assertEqual(sorted(cv.get_all_node_indexes()), sorted(ni_to_nc.keys()))
    self.assertEqual(
        sorted(cv.get_all_node_views(), key=operator.attrgetter("node_index")),
        sorted(
            (
                NodeView(
                    node_config=ni_to_nc[ni],
                    node_state=ni_to_ns[ni],
                    maintenances=tuple(ni_to_mnts[ni]),
                )
                for ni in ni_to_nc.keys()
            ),
            key=operator.attrgetter("node_index"),
        ),
    )
    self.assertEqual(sorted(cv.get_all_node_names()), sorted(nc.name for nc in ncs))
    self.assertEqual(
        sorted(cv.get_all_maintenance_ids()), sorted(mnt.group_id for mnt in mnts)
    )
    self.assertEqual(
        sorted(cv.get_all_maintenances(), key=operator.attrgetter("group_id")),
        sorted(mnts, key=operator.attrgetter("group_id")),
    )

    for ni in nis:
        nn = ni_to_nc[ni].name
        nc = ni_to_nc[ni]
        ns = ni_to_ns[ni]
        node_mnts = tuple(ni_to_mnts[ni])
        nv = NodeView(
            node_config=ni_to_nc[ni],
            node_state=ni_to_ns[ni],
            maintenances=node_mnts,
        )
        self.assertEqual(cv.get_node_view_by_node_index(ni), nv)
        self.assertEqual(cv.get_node_name_by_node_index(ni), nn)
        self.assertEqual(cv.get_node_config_by_node_index(ni), nc)
        self.assertEqual(cv.get_node_state_by_node_index(ni), ns)
        self.assertEqual(cv.get_node_maintenances_by_node_index(ni), node_mnts)
        self.assertEqual(cv.get_node_view_by_node_name(nn), nv)
        self.assertEqual(cv.get_node_index_by_node_name(nn), ni)
        self.assertEqual(cv.get_node_config_by_node_name(nn), nc)
        self.assertEqual(cv.get_node_state_by_node_name(nn), ns)
        self.assertEqual(cv.get_node_maintenances_by_node_name(nn), node_mnts)
        self.assertEqual(cv.get_node_view(node_name=nn), nv)
        self.assertEqual(cv.get_node_index(node_name=nn), ni)
        self.assertEqual(cv.get_node_config(node_name=nn), nc)
        self.assertEqual(cv.get_node_state(node_name=nn), ns)
        self.assertEqual(cv.get_node_maintenances(node_name=nn), node_mnts)
        self.assertEqual(cv.get_node_view(node_index=ni), nv)
        self.assertEqual(cv.get_node_name(node_index=ni), nn)
        self.assertEqual(cv.get_node_config(node_index=ni), nc)
        self.assertEqual(cv.get_node_state(node_index=ni), ns)
        self.assertEqual(cv.get_node_maintenances(node_index=ni), node_mnts)

    with self.assertRaises(ValueError):
        cv.get_node_view(None, None)
    with self.assertRaises(ValueError):
        cv.get_node_config(None, None)
    with self.assertRaises(ValueError):
        cv.get_node_state(None, None)
    with self.assertRaises(ValueError):
        cv.get_node_maintenances(None, None)

    # mismatch node_index and node_name
    if len(nis) > 1:
        nn = ni_to_nc[nis[0]].name
        ni = nis[1]
        with self.assertRaises(ValueError):
            cv.get_node_view(ni, nn)
        with self.assertRaises(ValueError):
            cv.get_node_config(ni, nn)
        with self.assertRaises(ValueError):
            cv.get_node_state(ni, nn)
        with self.assertRaises(ValueError):
            cv.get_node_maintenances(ni, nn)

    # non-existent node_index
    with self.assertRaises(NodeNotFoundError):
        cv.get_node_view(node_index=max(nis) + 1)

    # non-existent node_name
    with self.assertRaises(NodeNotFoundError):
        nns = {nc.name for nc in ncs}
        while True:
            nn = gen_word()
            if nn not in nns:
                break
        cv.get_node_view(node_name=nn)

    for mnt in mnts:
        assert mnt.group_id is not None
        self.assertEqual(cv.get_maintenance_by_id(mnt.group_id), mnt)
        self.assertTupleEqual(
            cv.get_node_indexes_by_maintenance_id(mnt.group_id),
            tuple(
                sorted(
                    set(
                        {
                            n.node_index
                            for n in mnt.sequencer_nodes
                            if n.node_index is not None
                        }
                    ).union(
                        {
                            s.node.node_index
                            for s in mnt.shards
                            if s.node.node_index is not None
                        }
                    )
                )
            ),
        )
        self.assertEqual(
            mnt.group_id, cv.get_maintenance_view_by_id(mnt.group_id).group_id
        )

    self.assertListEqual(
        list(sorted(m.group_id for m in mnts)),
        list(sorted(mv.group_id for mv in cv.get_all_maintenance_views())),
    )

    # expand_shards
    self.assertEqual(
        cv.expand_shards(
            shards=[ShardID(node=NodeID(node_index=nis[0]), shard_index=0)]
        ),
        (
            ShardID(
                node=NodeID(
                    node_index=ni_to_nc[nis[0]].node_index,
                    name=ni_to_nc[nis[0]].name,
                    address=ni_to_nc[nis[0]].data_address,
                ),
                shard_index=0,
            ),
        ),
    )
    self.assertEqual(
        len(
            cv.expand_shards(
                shards=[
                    ShardID(node=NodeID(node_index=nis[0]), shard_index=ALL_SHARDS)
                ]
            )
        ),
        ni_to_nc[nis[0]].storage.num_shards,
    )
    self.assertEqual(
        len(
            cv.expand_shards(
                shards=[
                    ShardID(node=NodeID(node_index=nis[0]), shard_index=ALL_SHARDS),
                    ShardID(node=NodeID(node_index=nis[0]), shard_index=ALL_SHARDS),
                    ShardID(node=NodeID(node_index=nis[1]), shard_index=ALL_SHARDS),
                ]
            )
        ),
        ni_to_nc[nis[0]].storage.num_shards + ni_to_nc[nis[1]].storage.num_shards,
    )
    self.assertEqual(
        len(
            cv.expand_shards(
                shards=[
                    ShardID(node=NodeID(node_index=nis[0]), shard_index=ALL_SHARDS),
                    ShardID(node=NodeID(node_index=nis[1]), shard_index=0),
                ],
                node_ids=[NodeID(node_index=0)],
            )
        ),
        ni_to_nc[nis[0]].storage.num_shards + 1,
    )

    # normalize_node_id
    self.assertEqual(
        cv.normalize_node_id(NodeID(node_index=nis[0])),
        NodeID(
            node_index=nis[0],
            address=ni_to_nc[nis[0]].data_address,
            name=ni_to_nc[nis[0]].name,
        ),
    )
    self.assertEqual(
        cv.normalize_node_id(NodeID(name=ni_to_nc[nis[0]].name)),
        NodeID(
            node_index=nis[0],
            address=ni_to_nc[nis[0]].data_address,
            name=ni_to_nc[nis[0]].name,
        ),
    )

    # search_maintenances
    self.assertEqual(len(cv.search_maintenances()), len(mnts))
    self.assertEqual(
        len(cv.search_maintenances(node_ids=[cv.get_node_id(node_index=3)])), 0
    )
    self.assertEqual(
        len(cv.search_maintenances(node_ids=[cv.get_node_id(node_index=1)])), 1
    )
    self.assertEqual(
        len(
            cv.search_maintenances(
                shards=[ShardID(node=cv.get_node_id(node_index=0), shard_index=1)]
            )
        ),
        1,
    )

    # shard_target_state
    self.assertEqual(
        len(
            cv.search_maintenances(
                shard_target_state=ShardOperationalState.MAY_DISAPPEAR
            )
        ),
        2,
    )
    self.assertEqual(
        len(
            cv.search_maintenances(
                shard_target_state=ShardOperationalState.DRAINED
            )
        ),
        0,
    )

    # sequencer_target_state
    self.assertEqual(
        len(cv.search_maintenances(sequencer_target_state=SequencingState.ENABLED)),
        0,
    )
    self.assertEqual(
        len(cv.search_maintenances(sequencer_target_state=SequencingState.DISABLED)),
        2,
    )

    self.assertEqual(len(cv.search_maintenances(user="******")), 1)
    self.assertEqual(len(cv.search_maintenances(reason="whatever")), 1)
    self.assertEqual(len(cv.search_maintenances(skip_safety_checks=True)), 0)
    self.assertEqual(len(cv.search_maintenances(skip_safety_checks=False)), 2)
    self.assertEqual(len(cv.search_maintenances(force_restore_rebuilding=True)), 0)
    self.assertEqual(len(cv.search_maintenances(force_restore_rebuilding=False)), 2)
    self.assertEqual(len(cv.search_maintenances(allow_passive_drains=True)), 0)
    self.assertEqual(len(cv.search_maintenances(allow_passive_drains=False)), 2)
    self.assertEqual(len(cv.search_maintenances(group_id=mnts[0].group_id)), 1)
    self.assertEqual(
        len(cv.search_maintenances(progress=MaintenanceProgress.IN_PROGRESS)), 2
    )
async def test_shard_maintenance_status(self):
    ## MAY_DISAPPEAR maintenance
    async with MockAdminAPI() as client:
        cv = await get_cluster_view(client)
        shard = ShardID(
            node=cv.get_node_view_by_node_index(0).node_id, shard_index=1
        )
        await apply_maintenance(
            client=client,
            shards=[shard],
            shard_target_state=ShardOperationalState.MAY_DISAPPEAR,
        )
        cv = await get_cluster_view(client)

        # Just started
        mv = list(cv.get_all_maintenance_views())[0]
        self.assertEqual(
            mv.get_shard_maintenance_status(shard), MaintenanceStatus.NOT_STARTED
        )
        self.assertEqual(mv.num_shards_done, 0)
        self.assertFalse(mv.are_all_shards_done)
        self.assertFalse(mv.is_everything_done)
        self.assertFalse(mv.is_blocked)
        self.assertFalse(mv.is_completed)
        self.assertTrue(mv.is_in_progress)
        self.assertFalse(mv.is_internal)
        self.assertEqual(mv.overall_status, MaintenanceOverallStatus.IN_PROGRESS)

        # In progress
        client._set_shard_maintenance_progress(
            shard,
            ShardMaintenanceProgress(
                status=MaintenanceStatus.STARTED,
                target_states=[ShardOperationalState.MAY_DISAPPEAR],
                created_at=datetime.now(),
                last_updated_at=datetime.now(),
                associated_group_ids=["johnsnow"],
            ).to_thrift(),
        )
        cv = await get_cluster_view(client)
        mv = list(cv.get_all_maintenance_views())[0]
        self.assertEqual(
            mv.get_shard_maintenance_status(shard), MaintenanceStatus.STARTED
        )
        self.assertEqual(mv.num_shards_done, 0)
        self.assertFalse(mv.are_all_shards_done)
        self.assertFalse(mv.is_everything_done)
        self.assertFalse(mv.is_blocked)
        self.assertFalse(mv.is_completed)
        self.assertTrue(mv.is_in_progress)
        self.assertFalse(mv.is_internal)
        self.assertEqual(mv.overall_status, MaintenanceOverallStatus.IN_PROGRESS)

        # Blocked
        client._set_shard_maintenance_progress(
            shard,
            ShardMaintenanceProgress(
                status=MaintenanceStatus.BLOCKED_UNTIL_SAFE,
                target_states=[ShardOperationalState.MAY_DISAPPEAR],
                created_at=datetime.now(),
                last_updated_at=datetime.now(),
                associated_group_ids=["johnsnow"],
            ).to_thrift(),
        )
        cv = await get_cluster_view(client)
        mv = list(cv.get_all_maintenance_views())[0]
        self.assertTrue(mv.is_blocked)
        self.assertFalse(mv.are_all_shards_done)
        self.assertFalse(mv.is_everything_done)
        self.assertTrue(mv.is_blocked)
        self.assertFalse(mv.is_completed)
        self.assertFalse(mv.is_in_progress)
        self.assertEqual(mv.overall_status, MaintenanceOverallStatus.BLOCKED)

        # Done
        for sos in {
            ShardOperationalState.DRAINED,
            ShardOperationalState.MAY_DISAPPEAR,
            ShardOperationalState.MIGRATING_DATA,
            ShardOperationalState.PROVISIONING,
        }:
            client._set_shard_current_operational_state(shard, sos)
            cv = await get_cluster_view(client)
            mv = list(cv.get_all_maintenance_views())[0]
            self.assertEqual(
                mv.get_shard_maintenance_status(shard), MaintenanceStatus.COMPLETED
            )
            self.assertEqual(mv.num_shards_done, 1)
            self.assertTrue(mv.are_all_shards_done)
            self.assertTrue(mv.is_everything_done)
            self.assertFalse(mv.is_blocked)
            self.assertTrue(mv.is_completed)
            self.assertFalse(mv.is_in_progress)
            self.assertEqual(mv.overall_status, MaintenanceOverallStatus.COMPLETED)

    ## DRAINED maintenance
    async with MockAdminAPI() as client:
        cv = await get_cluster_view(client)
        shard = ShardID(
            node=cv.get_node_view_by_node_index(0).node_id, shard_index=1
        )
        await apply_maintenance(
            client=client,
            shards=[shard],
            shard_target_state=ShardOperationalState.DRAINED,
        )
        cv = await get_cluster_view(client)

        # Just started
        mv = list(cv.get_all_maintenance_views())[0]
        self.assertEqual(
            mv.get_shard_maintenance_status(shard), MaintenanceStatus.NOT_STARTED
        )
        self.assertEqual(mv.num_shards_done, 0)
        self.assertFalse(mv.are_all_shards_done)

        # May disappear
        client._set_shard_current_operational_state(
            shard, ShardOperationalState.MAY_DISAPPEAR
        )
        cv = await get_cluster_view(client)
        mv = list(cv.get_all_maintenance_views())[0]
        self.assertEqual(
            mv.get_shard_maintenance_status(shard), MaintenanceStatus.NOT_STARTED
        )
        self.assertEqual(mv.num_shards_done, 0)
        self.assertFalse(mv.are_all_shards_done)

        # Done
        client._set_shard_current_operational_state(
            shard, ShardOperationalState.DRAINED
        )
        cv = await get_cluster_view(client)
        mv = list(cv.get_all_maintenance_views())[0]
        self.assertEqual(
            mv.get_shard_maintenance_status(shard), MaintenanceStatus.COMPLETED
        )
        self.assertEqual(mv.num_shards_done, 1)
        self.assertTrue(mv.are_all_shards_done)
def search_maintenances(
    self,
    node_ids: Optional[Collection[NodeID]] = None,
    shards: Optional[Collection[ShardID]] = None,
    shard_target_state: Optional[ShardOperationalState] = None,
    sequencer_nodes: Optional[Collection[NodeID]] = None,
    sequencer_target_state: Optional[SequencingState] = None,
    user: Optional[str] = None,
    reason: Optional[str] = None,
    skip_safety_checks: Optional[bool] = None,
    force_restore_rebuilding: Optional[bool] = None,
    allow_passive_drains: Optional[bool] = None,
    group_id: Optional[str] = None,
    progress: Optional[MaintenanceProgress] = None,
) -> Tuple[MaintenanceView, ...]:
    mvs = self.get_all_maintenance_views()

    if node_ids is not None:
        sequencer_nodes = list(sequencer_nodes or []) + list(node_ids)
        shards = list(shards or []) + [
            ShardID(node=node_id, shard_index=ALL_SHARDS) for node_id in node_ids
        ]

    if shards is not None:
        search_shards = self.expand_shards(shards)
        mvs = (mv for mv in mvs if self.expand_shards(mv.shards) == search_shards)

    if shard_target_state is not None:
        mvs = (mv for mv in mvs if shard_target_state == mv.shard_target_state)

    if sequencer_nodes is not None:
        normalized_sequencer_node_indexes = tuple(
            sorted(self.normalize_node_id(n).node_index for n in sequencer_nodes)
        )
        mvs = (
            mv
            for mv in mvs
            if normalized_sequencer_node_indexes
            == mv.affected_sequencer_node_indexes
        )

    if sequencer_target_state is not None:
        mvs = (
            mv for mv in mvs if mv.sequencer_target_state == sequencer_target_state
        )

    if user is not None:
        mvs = (mv for mv in mvs if mv.user == user)

    if reason is not None:
        mvs = (mv for mv in mvs if mv.reason == reason)

    if skip_safety_checks is not None:
        mvs = (mv for mv in mvs if mv.skip_safety_checks == skip_safety_checks)

    if force_restore_rebuilding is not None:
        mvs = (
            mv
            for mv in mvs
            if mv.force_restore_rebuilding == force_restore_rebuilding
        )

    if allow_passive_drains is not None:
        mvs = (mv for mv in mvs if mv.allow_passive_drains == allow_passive_drains)

    if group_id is not None:
        mvs = (mv for mv in mvs if mv.group_id == group_id)

    if progress is not None:
        mvs = (mv for mv in mvs if mv.progress == progress)

    return tuple(mvs)
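# A minimal usage sketch for search_maintenances (not from the original
# source). `cv` is an already-built ClusterView; the user string is made up
# and the helper name is hypothetical.
def _example_search_maintenances(cv: "ClusterView") -> None:
    # Filters combine with AND semantics; an empty tuple means no match.
    drains = cv.search_maintenances(
        user="ops-oncall",
        shard_target_state=ShardOperationalState.DRAINED,
    )
    for mv in drains:
        print(mv.group_id)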