Example #1
    def expand_shards(
        self,
        shards: Optional[Collection[ShardID]] = None,
        node_ids: Optional[Collection[NodeID]] = None,
    ) -> Tuple[ShardID, ...]:
        """
        Expands the given shards and node_ids into a deduplicated, sorted
        tuple of concrete ShardIDs. A NodeID, or a ShardID whose shard_index
        is ALL_SHARDS, expands to one ShardID per shard on that storage
        node; non-storage nodes are skipped.
        """
        shards = list(shards or [])
        node_ids = list(node_ids or [])
        for node_id in node_ids:
            shards.append(ShardID(node=node_id, shard_index=ALL_SHARDS))

        ret: Set[ShardID] = set()
        for shard in shards:
            node_view = self.get_node_view(
                node_index=shard.node.node_index, node_name=shard.node.name
            )
            if not node_view.is_storage:
                continue

            if shard.shard_index == ALL_SHARDS:
                r = range(0, node_view.num_shards)
            else:
                r = range(shard.shard_index, shard.shard_index + 1)

            for shard_index in r:
                ret.add(ShardID(node=node_view.node_id, shard_index=shard_index))

        return tuple(
            sorted(ret, key=lambda shard: (shard.node.node_index, shard.shard_index))
        )
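
A hedged usage sketch of expand_shards, assuming a ClusterView cv whose node 0 is a storage node with two shards; ShardID, NodeID, and ALL_SHARDS are assumed to be in scope (their import paths are not shown in these examples):

    # A wildcard entry (shard_index == ALL_SHARDS) expands to one concrete
    # ShardID per shard on the node; the set deduplicates overlapping input.
    expanded = cv.expand_shards(
        shards=[ShardID(node=NodeID(node_index=0), shard_index=ALL_SHARDS)]
    )
    # Hypothetically: (ShardID(node=<node 0>, shard_index=0),
    #                  ShardID(node=<node 0>, shard_index=1))

    # Passing node_ids is equivalent: each NodeID is turned into an
    # ALL_SHARDS ShardID before expansion.
    assert cv.expand_shards(node_ids=[NodeID(node_index=0)]) == expanded
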
    async def test_get_shard_last_updated_at(self):
        async with MockAdminAPI() as client:
            cv = await get_cluster_view(client)
            shard = ShardID(node=cv.get_node_view_by_node_index(0).node_id,
                            shard_index=1)
            await apply_maintenance(
                client=client,
                shards=[shard],
                shard_target_state=ShardOperationalState.DRAINED,
            )
            cv = await get_cluster_view(client)
            mv = list(cv.get_all_maintenance_views())[0]
            self.assertIsNone(mv.get_shard_last_updated_at(shard))

            ts = datetime.now()
            client._set_shard_maintenance_progress(
                shard,
                ShardMaintenanceProgress(
                    status=MaintenanceStatus.STARTED,
                    target_states=[ShardOperationalState.MAY_DISAPPEAR],
                    created_at=ts,
                    last_updated_at=datetime.now(),
                    associated_group_ids=["johnsnow"],
                ).to_thrift(),
            )
            cv = await get_cluster_view(client)
            mv = list(cv.get_all_maintenance_views())[0]
            self.assertTrue(
                mv.get_shard_last_updated_at(shard) -
                ts < timedelta(seconds=2))
Example #3
async def check_impact(
    client: AdminAPI,
    nodes: Optional[Collection[Node]] = None,
    shards: Optional[Collection[ShardID]] = None,
    target_storage_state: ShardStorageState = ShardStorageState.DISABLED,
    disable_sequencers: bool = True,
) -> CheckImpactResponse:
    """
    Performs a safety check and returns a CheckImpactResponse. If neither
    nodes nor shards are passed, it still runs the safety check but reports
    the current state of the cluster.
    """
    nodes = nodes or []
    shards = shards or []

    req_shards: FrozenSet[ShardID] = _recombine_shards(
        list(shards)  # shards is generic Collection, not List
        + [
            ShardID(
                node=NodeID(node_index=n.node_index,
                            address=n.data_addr.to_thrift()),
                shard_index=-1,  # -1 (ALL_SHARDS) selects every shard on the node
            ) for n in nodes
        ])

    return await admin_api.check_impact(
        client=client,
        req=CheckImpactRequest(
            shards=req_shards,
            target_storage_state=target_storage_state,
            disable_sequencers=disable_sequencers,
        ),
    )
Example #4
def parse_shards(src: Collection[str]) -> Set[ShardID]:
    """
    Parses a list of strings and interprets them as ShardID objects.
    Accepted examples:
        0 => ShardID(0, -1)
        N0 => ShardID(0, -1)
        0:2 => ShardID(0, 2)
        N0:2 => ShardID(0, 2)
        N0:S2 => ShardID(0, 2)
    """
    regex = re.compile(r"^N?(\d+):?S?(\d+)?$", flags=re.IGNORECASE)
    res = set()
    for s in src:
        match = regex.search(s)
        if not match:
            raise ValueError(f"Cannot parse shard: {s}")

        node_index = int(match.group(1))
        if match.group(2) is None:
            shard_index = ALL_SHARDS
        else:
            shard_index = int(match.group(2))
        res.add(
            ShardID(node=NodeID(node_index=node_index),
                    shard_index=shard_index))

    return res
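
The tests in Examples #8 and #16 below pin down this grammar; as a quick hedged sketch (names assumed in scope):

    # Case is ignored and duplicates collapse, since the result is a set.
    parse_shards(["N0:S1", "n0:s1", "5"])
    # == {ShardID(node=NodeID(node_index=0), shard_index=1),
    #     ShardID(node=NodeID(node_index=5), shard_index=ALL_SHARDS)}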
Example #5
    async def test_smoke(self):
        async with MockAdminAPI() as client:
            cv = await get_cluster_view(client)
            await apply_maintenance(
                client=client,
                shards=[
                    ShardID(node=cv.get_node_view_by_node_index(0).node_id,
                            shard_index=1)
                ],
                sequencer_nodes=[cv.get_node_view_by_node_index(0).node_id],
            )
            await apply_maintenance(
                client=client,
                node_ids=[cv.get_node_id(node_index=1)],
                user="******",
                reason="whatever",
            )
            (cv, nc_resp, ns_resp, mnts_resp) = await asyncio.gather(
                get_cluster_view(client),
                client.getNodesConfig(NodesFilter()),
                client.getNodesState(NodesStateRequest()),
                client.getMaintenances(MaintenancesFilter()),
            )
        self._validate(cv, nc_resp.nodes, ns_resp.states,
                       tuple(mnts_resp.maintenances))
Example #6
async def check_impact(
    client: AdminAPI,
    node_ids: Optional[Collection[NodeID]] = None,
    shards: Optional[Collection[ShardID]] = None,
    target_storage_state: ShardStorageState = ShardStorageState.DISABLED,
    disable_sequencers: bool = True,
) -> CheckImpactResponse:
    """
    Performs a safety check and returns a CheckImpactResponse. If neither
    node_ids nor shards are passed, it still runs the safety check but
    reports the current state of the cluster.
    """
    node_ids = set(node_ids or [])
    shards = set(shards or [])

    req_shards: Set[ShardID] = _recombine_shards(
        shards.union(
            ShardID(node=n, shard_index=ALL_SHARDS) for n in node_ids))

    return await admin_api.check_impact(
        client=client,
        req=CheckImpactRequest(
            shards=list(req_shards),
            target_storage_state=target_storage_state,
            disable_sequencers=disable_sequencers,
        ),
    )
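
A hedged call sketch for this variant of check_impact, assuming an AdminAPI client and the thrift types are in scope; the node and shard indexes are illustrative:

    # Ask what would happen if all shards of node 1 and shard N0:S2 were
    # disabled, without applying any maintenance.
    resp: CheckImpactResponse = await check_impact(
        client=client,
        node_ids=[NodeID(node_index=1)],
        shards=[ShardID(node=NodeID(node_index=0), shard_index=2)],
        target_storage_state=ShardStorageState.DISABLED,
    )
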
async def apply_maintenance(
    client: AdminAPI,
    node_ids: Optional[Collection[NodeID]] = None,
    shards: Optional[Collection[ShardID]] = None,
    shard_target_state: Optional[
        ShardOperationalState
    ] = ShardOperationalState.MAY_DISAPPEAR,
    sequencer_nodes: Optional[Collection[NodeID]] = None,
    group: Optional[bool] = True,
    ttl: Optional[timedelta] = None,
    user: Optional[str] = None,
    reason: Optional[str] = None,
    extras: Optional[Mapping[str, str]] = None,
    skip_safety_checks: Optional[bool] = False,
    allow_passive_drains: Optional[bool] = False,
) -> Collection[MaintenanceDefinition]:
    """
    Applies maintenance to MaintenanceManager.
    If the `node_ids` argument is specified, those nodes are treated as both
        shards and sequencers simultaneously.

    Can return multiple maintenances if group==False.
    """
    node_ids = set(node_ids or [])
    shards = set(shards or [])
    sequencer_nodes = set(sequencer_nodes or []).union(node_ids)

    if ttl is None:
        ttl = timedelta(seconds=0)

    if user is None:
        user = "******"

    if reason is None:
        reason = "Not Specified"

    if extras is None:
        extras = {}

    shards = shards.union({ShardID(node=n, shard_index=-1) for n in node_ids})
    shards = _recombine_shards(shards)

    req = MaintenanceDefinition(
        shards=list(shards),
        shard_target_state=shard_target_state,
        sequencer_nodes=list(sequencer_nodes),
        sequencer_target_state=SequencingState.DISABLED,
        user=user,
        reason=reason,
        extras=extras,
        skip_safety_checks=skip_safety_checks,
        group=group,
        ttl_seconds=int(ttl.total_seconds()),
        allow_passive_drains=allow_passive_drains,
    )
    resp: MaintenanceDefinitionResponse = await admin_api.apply_maintenance(
        client=client, req=req
    )
    return resp.maintenances
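
Examples #5 and #11 exercise this helper end to end; the minimal shape of a call is (client and node_id assumed in scope):

    # One grouped maintenance covering a single shard plus sequencing on
    # the same node; ttl, user, reason, and extras fall back to defaults.
    maintenances = await apply_maintenance(
        client=client,
        shards=[ShardID(node=node_id, shard_index=1)],
        sequencer_nodes=[node_id],
        shard_target_state=ShardOperationalState.MAY_DISAPPEAR,
    )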
Example #8
    def test_parse_shards_valid2(self) -> None:
        # Parse multiple inputs
        self.assertEqual(
            {
                ShardID(node=NodeID(node_index=0), shard_index=1),
                ShardID(node=NodeID(node_index=1), shard_index=2),
            },
            helpers.parse_shards(["N0:S1", "N1:S2"]),
        )

        # Remove duplicates
        self.assertEqual(
            {
                ShardID(node=NodeID(node_index=0), shard_index=1),
                ShardID(node=NodeID(node_index=1), shard_index=2),
            },
            helpers.parse_shards(["N0:S1", "N1:S2", "N0:s1"]),
        )
Example #9
    async def applyMaintenance(
        self, request: MaintenanceDefinition
    ) -> MaintenanceDefinitionResponse:
        # TODO: ungroup if group == False
        shards = []
        for sh in request.shards:
            if sh.shard_index == -1:
                # TODO: expand whole-node (-1) shards instead of dropping them
                pass
            else:
                assert sh.node.node_index is not None
                # pyre-fixme[6]: Expected `int` for 1st param but got `Optional[int]`.
                nc = self._nc_by_node_index[sh.node.node_index]
                shards.append(
                    ShardID(
                        node=NodeID(
                            node_index=nc.node_index,
                            name=nc.name,
                            address=nc.data_address,
                        ),
                        shard_index=sh.shard_index,
                    )
                )

        seq_nodes = []
        for n in request.sequencer_nodes:
            assert n.node_index is not None
            # pyre-fixme[6]: Expected `int` for 1st param but got `Optional[int]`.
            nc = self._nc_by_node_index[n.node_index]
            seq_nodes.append(
                NodeID(node_index=nc.node_index, name=nc.name, address=nc.data_address)
            )

        mnt = MaintenanceDefinition(
            shards=shards,
            shard_target_state=request.shard_target_state,
            sequencer_nodes=seq_nodes,
            sequencer_target_state=request.sequencer_target_state,
            user=request.user,
            reason=request.reason,
            extras=request.extras,
            skip_safety_checks=request.skip_safety_checks,
            force_restore_rebuilding=request.force_restore_rebuilding,
            group=request.group,
            ttl_seconds=request.ttl_seconds,
            allow_passive_drains=request.allow_passive_drains,
            group_id=gen_word(8),
            last_check_impact_result=None,
            expires_on=1000 * (int(datetime.now().timestamp()) + request.ttl_seconds)
            if request.ttl_seconds
            else None,
            created_on=1000 * int(datetime.now().timestamp()),
        )
        assert mnt.group_id is not None
        # pyre-fixme[6]: Expected `str` for 1st param but got `Optional[str]`.
        self._maintenances_by_id[mnt.group_id] = mnt
        return MaintenanceDefinitionResponse(maintenances=[mnt])
Example #10
    def _combine(
        cv: ClusterView,
        shards: Optional[List[str]] = None,
        node_names: Optional[List[str]] = None,
        node_indexes: Optional[List[int]] = None,
    ) -> Tuple[ShardID, ...]:
        """
        Merges shard strings, node names, and node indexes into one expanded
        tuple of ShardIDs, resolved against the given ClusterView.
        """
        shards = list(shards or [])
        node_names = list(node_names or [])
        node_indexes = list(node_indexes or [])

        shard_ids = parse_shards(shards)
        for nn in node_names:
            shard_ids.add(
                ShardID(node=cv.get_node_id(node_name=nn), shard_index=-1))
        for ni in node_indexes:
            shard_ids.add(ShardID(node=NodeID(node_index=ni), shard_index=-1))
        return cv.expand_shards(shard_ids)
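
A hedged sketch of how _combine merges the three addressing styles, assuming a ClusterView cv; the node name is hypothetical:

    # Shard strings go through parse_shards; node names and indexes are
    # added as whole-node (shard_index == -1) entries, then everything is
    # expanded into concrete ShardIDs.
    shard_ids = _combine(
        cv,
        shards=["N0:S1"],
        node_names=["some-node-name"],  # hypothetical
        node_indexes=[2],
    )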
Example #11
    async def smoke(self, client) -> None:
        cv = await get_cluster_view(client)
        storages_node_views = [nv for nv in cv.get_all_node_views() if nv.is_storage]
        sequencers_node_views = [
            nv for nv in cv.get_all_node_views() if nv.is_sequencer
        ]

        # combined maintenance, storages and sequencers
        await apply_maintenance(
            client=client,
            shards=[ShardID(node=storages_node_views[0].node_id, shard_index=1)],
            sequencer_nodes=[sequencers_node_views[0].node_id],
        )

        # storage-only maintenance
        await apply_maintenance(
            client=client,
            shards=[ShardID(node=storages_node_views[1].node_id, shard_index=1)],
        )

        # sequencer-only maintenance
        await apply_maintenance(
            client=client, sequencer_nodes=[sequencers_node_views[2].node_id]
        )

        # maintenance for whole nodes
        await apply_maintenance(
            client=client,
            node_ids=[storages_node_views[3].node_id, sequencers_node_views[4].node_id],
            user="******",
            reason="whatever",
        )

        (cv, nc_resp, ns_resp, mnts_resp) = await asyncio.gather(
            get_cluster_view(client),
            client.getNodesConfig(NodesFilter()),
            client.getNodesState(NodesStateRequest()),
            client.getMaintenances(MaintenancesFilter()),
        )
        self._validate(cv, nc_resp.nodes, ns_resp.states, tuple(mnts_resp.maintenances))
Example #12
    async def test_smoke(self):
        ni = 0
        async with MockAdminAPI() as client:
            cv = await get_cluster_view(client)
            await apply_maintenance(
                client=client,
                shards=[
                    ShardID(
                        node=cv.get_node_view_by_node_index(0).node_id, shard_index=1
                    )
                ],
                sequencer_nodes=[cv.get_node_view_by_node_index(0).node_id],
            )
            (
                nodes_config_resp,
                nodes_state_resp,
                maintenances_resp,
            ) = await asyncio.gather(
                client.getNodesConfig(NodesFilter(node=NodeID(node_index=ni))),
                client.getNodesState(
                    NodesStateRequest(filter=NodesFilter(node=NodeID(node_index=ni)))
                ),
                client.getMaintenances(MaintenancesFilter()),
            )

        nc = [n for n in nodes_config_resp.nodes if n.node_index == ni][0]
        ns = [n for n in nodes_state_resp.states if n.node_index == ni][0]
        mnt_ids = set()
        for mnt in maintenances_resp.maintenances:
            for s in mnt.shards:
                if s.node.node_index == ni:
                    mnt_ids.add(mnt.group_id)
            for n in mnt.sequencer_nodes:
                if n.node_index == ni:
                    mnt_ids.add(mnt.group_id)
        mnts = tuple(
            sorted(
                (
                    mnt
                    for mnt in maintenances_resp.maintenances
                    if mnt.group_id in mnt_ids
                ),
                key=operator.attrgetter("group_id"),
            )
        )

        nv = NodeView(node_config=nc, node_state=ns, maintenances=mnts)

        self._validate(nv, nc, ns, mnts)

    async def test_shard_only(self):
        async with MockAdminAPI() as client:
            cv = await get_cluster_view(client)
            await apply_maintenance(
                client=client,
                shards=[
                    ShardID(node=cv.get_node_view_by_node_index(0).node_id,
                            shard_index=1)
                ],
            )
            cv = await get_cluster_view(client)

        self.validate(
            maintenance_view=list(cv.get_all_maintenance_views())[0],
            maintenance=list(cv.get_all_maintenances())[0],
            node_index_to_node_view={0: cv.get_node_view(node_index=0)},
        )

    async def test_node_is_not_a_sequencer(self):
        async with MockAdminAPI(disaggregated=True) as client:
            cv = await get_cluster_view(client)
            shard = ShardID(node=cv.get_node_view_by_node_index(0).node_id,
                            shard_index=1)
            await apply_maintenance(
                client=client,
                shards=[shard],
                shard_target_state=ShardOperationalState.DRAINED,
            )
            cv = await get_cluster_view(client)

        mv = list(cv.get_all_maintenance_views())[0]
        node_id = cv.get_node_view_by_node_index(0).node_id
        with self.assertRaises(NodeIsNotASequencerError):
            mv.get_sequencer_maintenance_status(node_id)

        with self.assertRaises(NodeIsNotASequencerError):
            mv.get_sequencer_last_updated_at(node_id)
Example #15
def to_shard_id(scope: str) -> ShardID:
    """
    A conversion utility that takes an Nx:Sy string and converts it into a
    typed ShardID. The 'Sy' part is optional; if unset, the generated
    ShardID has its shard_index set to -1.
    """
    scope = scope.upper()
    if not scope:
        raise ValueError("Cannot parse an empty scope")
    match = SHARD_PATTERN.match(scope)
    if match is None:
        # The scope does not match the expected Nx[:Sy] format.
        raise ValueError(f"Cannot parse '{scope}'. Invalid format!")

    results = match.groupdict()
    shard_index = -1
    if results["shard_index"] is not None:
        shard_index = int(results["shard_index"])
    node_index = int(results["node_index"])
    node = NodeID(node_index=node_index)
    return ShardID(node=node, shard_index=shard_index)
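
A quick hedged sketch of the accepted forms (SHARD_PATTERN itself is defined outside this excerpt):

    to_shard_id("N1:S4")  # => ShardID(node=NodeID(node_index=1), shard_index=4)
    to_shard_id("n1")     # => ShardID(node=NodeID(node_index=1), shard_index=-1)
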
Example #16
    def test_parse_shards_valid1(self) -> None:
        # 5
        self.assertEqual(
            {ShardID(node=NodeID(node_index=5), shard_index=ALL_SHARDS)},
            helpers.parse_shards(["5"]),
        )
        # 5:1
        self.assertEqual(
            {ShardID(node=NodeID(node_index=5), shard_index=1)},
            helpers.parse_shards(["5:1"]),
        )
        # 0:S1
        self.assertEqual(
            {ShardID(node=NodeID(node_index=0), shard_index=1)},
            helpers.parse_shards(["0:S1"]),
        )
        # N0:S1
        self.assertEqual(
            {ShardID(node=NodeID(node_index=0), shard_index=1)},
            helpers.parse_shards(["N0:S1"]),
        )

        # N0 == ShardID(0, ALL_SHARDS)
        self.assertEqual(
            {ShardID(node=NodeID(node_index=0), shard_index=ALL_SHARDS)},
            helpers.parse_shards(["N0"]),
        )

        # N1:S4 == ShardID(1, 4)
        self.assertEqual(
            {ShardID(node=NodeID(node_index=1), shard_index=4)},
            helpers.parse_shards(["N1:S4"]),
        )

        # Case is ignored
        # n1:S4 == ShardID(1, 4)
        self.assertEqual(
            {ShardID(node=NodeID(node_index=1), shard_index=4)},
            helpers.parse_shards(["n1:S4"]),
        )
Example #17
    async def apply(
        self,
        reason: str,
        node_indexes: Optional[List[int]] = None,
        node_names: Optional[List[str]] = None,
        shards: Optional[List[str]] = None,
        shard_target_state: Optional[str] = "may-disappear",
        sequencer_node_indexes: Optional[List[int]] = None,
        sequencer_node_names: Optional[List[str]] = None,
        user: Optional[str] = "",
        group: Optional[bool] = True,
        skip_safety_checks: Optional[bool] = False,
        ttl: Optional[int] = 0,
        allow_passive_drains: Optional[bool] = False,
        force_restore_rebuilding: Optional[bool] = False,
    ):
        """
        Applies a new maintenance to the Maintenance Manager.
        """
        ctx = context.get_context()

        try:
            async with ctx.get_cluster_admin_client() as client:
                cv = await get_cluster_view(client)

            all_node_indexes = set()
            if node_indexes is not None:
                all_node_indexes = all_node_indexes.union(set(node_indexes))
            if node_names is not None:
                all_node_indexes = all_node_indexes.union({
                    cv.get_node_index(node_name=node_name)
                    for node_name in set(node_names)
                })

            shard_ids = set()
            sequencer_nodes = set()
            for ni in all_node_indexes:
                nv = cv.get_node_view(node_index=ni)
                if nv.is_storage:
                    shard_ids.add(ShardID(node=nv.node_id, shard_index=-1))
                if nv.is_sequencer:
                    sequencer_nodes.add(nv.node_id)

            if sequencer_node_indexes is not None:
                for ni in set(sequencer_node_indexes):
                    nv = cv.get_node_view(node_index=ni)
                    if nv.is_sequencer:
                        sequencer_nodes.add(nv.node_id)

            if sequencer_node_names is not None:
                for nn in set(sequencer_node_names):
                    nv = cv.get_node_view(node_name=nn)
                    if nv.is_sequencer:
                        sequencer_nodes.add(nv.node_id)

            if shards is not None:
                shard_ids = shard_ids.union(
                    cv.expand_shards(parse_shards(shards)))

        except NodeNotFoundError as e:
            print(colored(f"Node not found: {e}", "red"))
            return

        try:
            async with ctx.get_cluster_admin_client() as client:
                maintenances: Collection[MaintenanceDefinition]
                maintenances = await apply_maintenance(
                    client=client,
                    shards=shard_ids,
                    shard_target_state=_parse_shard_target_state(
                        shard_target_state),
                    sequencer_nodes=list(sequencer_nodes),
                    group=group,
                    ttl=timedelta(seconds=ttl),
                    user=user or getuser(),
                    reason=reason,
                    skip_safety_checks=skip_safety_checks,
                    allow_passive_drains=allow_passive_drains,
                    force_restore_rebuilding=force_restore_rebuilding,
                )
                cv = await get_cluster_view(client)
        except Exception as e:
            print(colored(f"Cannot apply maintenance: {e}", "red"))
            return

        print(
            _render(
                [
                    cv.get_maintenance_view_by_id(id)
                    for id in [mnt.group_id for mnt in maintenances]
                ],
                cv,
                mode=RenderingMode.EXPANDED,
            ))
Example #18
    def _validate(
        self,
        cv: ClusterView,
        ncs: List[NodeConfig],
        nss: List[NodeState],
        mnts: Tuple[MaintenanceDefinition, ...],
    ):
        nis = sorted(nc.node_index for nc in ncs)
        ni_to_nc = {nc.node_index: nc for nc in ncs}
        ni_to_ns = {ns.node_index: ns for ns in nss}
        ni_to_mnts: Dict[int, List[MaintenanceDefinition]] = {
            ni: [] for ni in nis
        }
        for mnt in mnts:
            mnt_nis = set()
            for s in mnt.shards:
                assert s.node.node_index is not None
                mnt_nis.add(s.node.node_index)
            for n in mnt.sequencer_nodes:
                assert n.node_index is not None
                mnt_nis.add(n.node_index)
            for ni in mnt_nis:
                ni_to_mnts[ni].append(mnt)

        self.assertEqual(sorted(cv.get_all_node_indexes()),
                         sorted(ni_to_nc.keys()))

        self.assertEqual(
            sorted(cv.get_all_node_views(),
                   key=operator.attrgetter("node_index")),
            sorted(
                (NodeView(
                    node_config=ni_to_nc[ni],
                    node_state=ni_to_ns[ni],
                    maintenances=tuple(ni_to_mnts[ni]),
                ) for ni in ni_to_nc.keys()),
                key=operator.attrgetter("node_index"),
            ),
        )

        self.assertEqual(sorted(cv.get_all_node_names()),
                         sorted(nc.name for nc in ncs))

        self.assertEqual(sorted(cv.get_all_maintenance_ids()),
                         sorted(mnt.group_id for mnt in mnts))

        self.assertEqual(
            sorted(cv.get_all_maintenances(),
                   key=operator.attrgetter("group_id")),
            sorted(mnts, key=operator.attrgetter("group_id")),
        )

        for ni in nis:
            nn = ni_to_nc[ni].name
            nc = ni_to_nc[ni]
            ns = ni_to_ns[ni]
            node_mnts = tuple(ni_to_mnts[ni])
            nv = NodeView(
                node_config=ni_to_nc[ni],
                node_state=ni_to_ns[ni],
                maintenances=node_mnts,
            )

            self.assertEqual(cv.get_node_view_by_node_index(ni), nv)
            self.assertEqual(cv.get_node_name_by_node_index(ni), nn)
            self.assertEqual(cv.get_node_config_by_node_index(ni), nc)
            self.assertEqual(cv.get_node_state_by_node_index(ni), ns)
            self.assertEqual(cv.get_node_maintenances_by_node_index(ni),
                             node_mnts)

            self.assertEqual(cv.get_node_view_by_node_name(nn), nv)
            self.assertEqual(cv.get_node_index_by_node_name(nn), ni)
            self.assertEqual(cv.get_node_config_by_node_name(nn), nc)
            self.assertEqual(cv.get_node_state_by_node_name(nn), ns)
            self.assertEqual(cv.get_node_maintenances_by_node_name(nn),
                             node_mnts)

            self.assertEqual(cv.get_node_view(node_name=nn), nv)
            self.assertEqual(cv.get_node_index(node_name=nn), ni)
            self.assertEqual(cv.get_node_config(node_name=nn), nc)
            self.assertEqual(cv.get_node_state(node_name=nn), ns)
            self.assertEqual(cv.get_node_maintenances(node_name=nn), node_mnts)

            self.assertEqual(cv.get_node_view(node_index=ni), nv)
            self.assertEqual(cv.get_node_name(node_index=ni), nn)
            self.assertEqual(cv.get_node_config(node_index=ni), nc)
            self.assertEqual(cv.get_node_state(node_index=ni), ns)
            self.assertEqual(cv.get_node_maintenances(node_index=ni),
                             node_mnts)

        with self.assertRaises(ValueError):
            cv.get_node_view(None, None)

        with self.assertRaises(ValueError):
            cv.get_node_config(None, None)

        with self.assertRaises(ValueError):
            cv.get_node_state(None, None)

        with self.assertRaises(ValueError):
            cv.get_node_maintenances(None, None)

        # mismatch node_index and node_name
        if len(nis) > 1:
            nn = ni_to_nc[nis[0]].name
            ni = nis[1]
            with self.assertRaises(ValueError):
                cv.get_node_view(ni, nn)

            with self.assertRaises(ValueError):
                cv.get_node_config(ni, nn)

            with self.assertRaises(ValueError):
                cv.get_node_state(ni, nn)

            with self.assertRaises(ValueError):
                cv.get_node_maintenances(ni, nn)

        # non-existent node_index
        with self.assertRaises(NodeNotFoundError):
            cv.get_node_view(node_index=max(nis) + 1)

        # non-existent node_name
        with self.assertRaises(NodeNotFoundError):
            nns = {nc.name for nc in ncs}
            while True:
                nn = gen_word()
                if nn not in nns:
                    break
            cv.get_node_view(node_name=nn)

        for mnt in mnts:
            assert mnt.group_id is not None
            self.assertEqual(cv.get_maintenance_by_id(mnt.group_id), mnt)
            self.assertTupleEqual(
                cv.get_node_indexes_by_maintenance_id(mnt.group_id),
                tuple(
                    sorted(
                        {
                            n.node_index
                            for n in mnt.sequencer_nodes
                            if n.node_index is not None
                        }.union({
                            s.node.node_index
                            for s in mnt.shards
                            if s.node.node_index is not None
                        }))),
            )
            self.assertEqual(
                mnt.group_id,
                cv.get_maintenance_view_by_id(mnt.group_id).group_id)

        self.assertListEqual(
            list(sorted(m.group_id for m in mnts)),
            list(sorted(mv.group_id for mv in cv.get_all_maintenance_views())),
        )

        # expand_shards
        self.assertEqual(
            cv.expand_shards(shards=[
                ShardID(node=NodeID(node_index=nis[0]), shard_index=0)
            ]),
            (ShardID(
                node=NodeID(
                    node_index=ni_to_nc[nis[0]].node_index,
                    name=ni_to_nc[nis[0]].name,
                    address=ni_to_nc[nis[0]].data_address,
                ),
                shard_index=0,
            ), ),
        )
        self.assertEqual(
            len(
                cv.expand_shards(shards=[
                    ShardID(node=NodeID(node_index=nis[0]),
                            shard_index=ALL_SHARDS)
                ])),
            ni_to_nc[nis[0]].storage.num_shards,
        )
        self.assertEqual(
            len(
                cv.expand_shards(shards=[
                    ShardID(node=NodeID(node_index=nis[0]),
                            shard_index=ALL_SHARDS),
                    ShardID(node=NodeID(node_index=nis[0]),
                            shard_index=ALL_SHARDS),
                    ShardID(node=NodeID(node_index=nis[1]),
                            shard_index=ALL_SHARDS),
                ])),
            ni_to_nc[nis[0]].storage.num_shards +
            ni_to_nc[nis[1]].storage.num_shards,
        )
        self.assertEqual(
            len(
                cv.expand_shards(
                    shards=[
                        ShardID(node=NodeID(node_index=nis[0]),
                                shard_index=ALL_SHARDS),
                        ShardID(node=NodeID(node_index=nis[1]), shard_index=0),
                    ],
                    node_ids=[NodeID(node_index=0)],
                )),
            ni_to_nc[nis[0]].storage.num_shards + 1,
        )

        # normalize_node_id
        self.assertEqual(
            cv.normalize_node_id(NodeID(node_index=nis[0])),
            NodeID(
                node_index=nis[0],
                address=ni_to_nc[nis[0]].data_address,
                name=ni_to_nc[nis[0]].name,
            ),
        )
        self.assertEqual(
            cv.normalize_node_id(NodeID(name=ni_to_nc[nis[0]].name)),
            NodeID(
                node_index=nis[0],
                address=ni_to_nc[nis[0]].data_address,
                name=ni_to_nc[nis[0]].name,
            ),
        )

        # search_maintenances
        self.assertEqual(len(cv.search_maintenances()), len(mnts))
        self.assertEqual(
            len(cv.search_maintenances(node_ids=[cv.get_node_id(
                node_index=3)])), 0)
        self.assertEqual(
            len(cv.search_maintenances(node_ids=[cv.get_node_id(
                node_index=1)])), 1)
        self.assertEqual(
            len(
                cv.search_maintenances(shards=[
                    ShardID(node=cv.get_node_id(node_index=0), shard_index=1)
                ])),
            1,
        )

        # shard_target_state
        self.assertEqual(
            len(
                cv.search_maintenances(
                    shard_target_state=ShardOperationalState.MAY_DISAPPEAR)),
            2,
        )
        self.assertEqual(
            len(
                cv.search_maintenances(
                    shard_target_state=ShardOperationalState.DRAINED)),
            0,
        )

        # sequencer_target_state
        self.assertEqual(
            len(
                cv.search_maintenances(
                    sequencer_target_state=SequencingState.ENABLED)),
            0,
        )
        self.assertEqual(
            len(
                cv.search_maintenances(
                    sequencer_target_state=SequencingState.DISABLED)),
            2,
        )

        self.assertEqual(len(cv.search_maintenances(user="******")), 1)
        self.assertEqual(len(cv.search_maintenances(reason="whatever")), 1)

        self.assertEqual(len(cv.search_maintenances(skip_safety_checks=True)),
                         0)
        self.assertEqual(len(cv.search_maintenances(skip_safety_checks=False)),
                         2)

        self.assertEqual(
            len(cv.search_maintenances(force_restore_rebuilding=True)), 0)
        self.assertEqual(
            len(cv.search_maintenances(force_restore_rebuilding=False)), 2)

        self.assertEqual(
            len(cv.search_maintenances(allow_passive_drains=True)), 0)
        self.assertEqual(
            len(cv.search_maintenances(allow_passive_drains=False)), 2)

        self.assertEqual(
            len(cv.search_maintenances(group_id=mnts[0].group_id)), 1)

        self.assertEqual(
            len(
                cv.search_maintenances(
                    progress=MaintenanceProgress.IN_PROGRESS)), 2)

    async def test_shard_maintenance_status(self):
        ## MAY_DISAPPEAR maintenance
        async with MockAdminAPI() as client:
            cv = await get_cluster_view(client)
            shard = ShardID(node=cv.get_node_view_by_node_index(0).node_id,
                            shard_index=1)
            await apply_maintenance(
                client=client,
                shards=[shard],
                shard_target_state=ShardOperationalState.MAY_DISAPPEAR,
            )
            cv = await get_cluster_view(client)

            # Just started
            mv = list(cv.get_all_maintenance_views())[0]
            self.assertEqual(mv.get_shard_maintenance_status(shard),
                             MaintenanceStatus.NOT_STARTED)
            self.assertEqual(mv.num_shards_done, 0)
            self.assertFalse(mv.are_all_shards_done)
            self.assertFalse(mv.is_everything_done)
            self.assertFalse(mv.is_blocked)
            self.assertFalse(mv.is_completed)
            self.assertTrue(mv.is_in_progress)
            self.assertFalse(mv.is_internal)
            self.assertEqual(mv.overall_status,
                             MaintenanceOverallStatus.IN_PROGRESS)

            # In progress
            client._set_shard_maintenance_progress(
                shard,
                ShardMaintenanceProgress(
                    status=MaintenanceStatus.STARTED,
                    target_states=[ShardOperationalState.MAY_DISAPPEAR],
                    created_at=datetime.now(),
                    last_updated_at=datetime.now(),
                    associated_group_ids=["johnsnow"],
                ).to_thrift(),
            )
            cv = await get_cluster_view(client)
            mv = list(cv.get_all_maintenance_views())[0]
            self.assertEqual(mv.get_shard_maintenance_status(shard),
                             MaintenanceStatus.STARTED)
            self.assertEqual(mv.num_shards_done, 0)
            self.assertFalse(mv.are_all_shards_done)
            self.assertFalse(mv.is_everything_done)
            self.assertFalse(mv.is_blocked)
            self.assertFalse(mv.is_completed)
            self.assertTrue(mv.is_in_progress)
            self.assertFalse(mv.is_internal)
            self.assertEqual(mv.overall_status,
                             MaintenanceOverallStatus.IN_PROGRESS)

            # Blocked
            client._set_shard_maintenance_progress(
                shard,
                ShardMaintenanceProgress(
                    status=MaintenanceStatus.BLOCKED_UNTIL_SAFE,
                    target_states=[ShardOperationalState.MAY_DISAPPEAR],
                    created_at=datetime.now(),
                    last_updated_at=datetime.now(),
                    associated_group_ids=["johnsnow"],
                ).to_thrift(),
            )
            cv = await get_cluster_view(client)
            mv = list(cv.get_all_maintenance_views())[0]
            self.assertTrue(mv.is_blocked)
            self.assertFalse(mv.are_all_shards_done)
            self.assertFalse(mv.is_everything_done)
            self.assertTrue(mv.is_blocked)
            self.assertFalse(mv.is_completed)
            self.assertFalse(mv.is_in_progress)
            self.assertEqual(mv.overall_status,
                             MaintenanceOverallStatus.BLOCKED)

            # Done
            for sos in {
                    ShardOperationalState.DRAINED,
                    ShardOperationalState.MAY_DISAPPEAR,
                    ShardOperationalState.MIGRATING_DATA,
                    ShardOperationalState.PROVISIONING,
            }:
                client._set_shard_current_operational_state(shard, sos)
                cv = await get_cluster_view(client)
                mv = list(cv.get_all_maintenance_views())[0]
                self.assertEqual(mv.get_shard_maintenance_status(shard),
                                 MaintenanceStatus.COMPLETED)
                self.assertEqual(mv.num_shards_done, 1)
                self.assertTrue(mv.are_all_shards_done)
                self.assertTrue(mv.is_everything_done)
                self.assertFalse(mv.is_blocked)
                self.assertTrue(mv.is_completed)
                self.assertFalse(mv.is_in_progress)
                self.assertEqual(mv.overall_status,
                                 MaintenanceOverallStatus.COMPLETED)

        ## DRAINED maintenance
        async with MockAdminAPI() as client:
            cv = await get_cluster_view(client)
            shard = ShardID(node=cv.get_node_view_by_node_index(0).node_id,
                            shard_index=1)
            await apply_maintenance(
                client=client,
                shards=[shard],
                shard_target_state=ShardOperationalState.DRAINED,
            )
            cv = await get_cluster_view(client)

            # Just started
            mv = list(cv.get_all_maintenance_views())[0]
            self.assertEqual(mv.get_shard_maintenance_status(shard),
                             MaintenanceStatus.NOT_STARTED)
            self.assertEqual(mv.num_shards_done, 0)
            self.assertFalse(mv.are_all_shards_done)

            # May disappear
            client._set_shard_current_operational_state(
                shard, ShardOperationalState.MAY_DISAPPEAR)
            cv = await get_cluster_view(client)
            mv = list(cv.get_all_maintenance_views())[0]
            self.assertEqual(mv.get_shard_maintenance_status(shard),
                             MaintenanceStatus.NOT_STARTED)
            self.assertEqual(mv.num_shards_done, 0)
            self.assertFalse(mv.are_all_shards_done)

            # Done
            client._set_shard_current_operational_state(
                shard, ShardOperationalState.DRAINED)
            cv = await get_cluster_view(client)
            mv = list(cv.get_all_maintenance_views())[0]
            self.assertEqual(mv.get_shard_maintenance_status(shard),
                             MaintenanceStatus.COMPLETED)
            self.assertEqual(mv.num_shards_done, 1)
            self.assertTrue(mv.are_all_shards_done)
Example #20
    def search_maintenances(
        self,
        node_ids: Optional[Collection[NodeID]] = None,
        shards: Optional[Collection[ShardID]] = None,
        shard_target_state: Optional[ShardOperationalState] = None,
        sequencer_nodes: Optional[Collection[NodeID]] = None,
        sequencer_target_state: Optional[SequencingState] = None,
        user: Optional[str] = None,
        reason: Optional[str] = None,
        skip_safety_checks: Optional[bool] = None,
        force_restore_rebuilding: Optional[bool] = None,
        allow_passive_drains: Optional[bool] = None,
        group_id: Optional[str] = None,
        progress: Optional[MaintenanceProgress] = None,
    ) -> Tuple[MaintenanceView, ...]:
        """
        Filters maintenance views: every argument that is not None narrows
        the result. node_ids are folded into both the shards and the
        sequencer_nodes filters.
        """
        mvs = self.get_all_maintenance_views()

        if node_ids is not None:
            sequencer_nodes = list(sequencer_nodes or []) + list(node_ids)
            shards = list(shards or []) + [
                ShardID(node=node_id, shard_index=ALL_SHARDS) for node_id in node_ids
            ]

        if shards is not None:
            search_shards = self.expand_shards(shards)
            mvs = (mv for mv in mvs if self.expand_shards(mv.shards) == search_shards)

        if shard_target_state is not None:
            mvs = (mv for mv in mvs if shard_target_state == mv.shard_target_state)

        if sequencer_nodes is not None:
            normalized_sequencer_node_indexes = tuple(
                sorted(self.normalize_node_id(n).node_index for n in sequencer_nodes)
            )
            mvs = (
                mv
                for mv in mvs
                if normalized_sequencer_node_indexes
                == mv.affected_sequencer_node_indexes
            )

        if sequencer_target_state is not None:
            mvs = (
                mv for mv in mvs if mv.sequencer_target_state == sequencer_target_state
            )

        if user is not None:
            mvs = (mv for mv in mvs if mv.user == user)

        if reason is not None:
            mvs = (mv for mv in mvs if mv.reason == reason)

        if skip_safety_checks is not None:
            mvs = (mv for mv in mvs if mv.skip_safety_checks == skip_safety_checks)

        if force_restore_rebuilding is not None:
            mvs = (
                mv
                for mv in mvs
                if mv.force_restore_rebuilding == force_restore_rebuilding
            )

        if allow_passive_drains is not None:
            mvs = (mv for mv in mvs if mv.allow_passive_drains == allow_passive_drains)

        if group_id is not None:
            mvs = (mv for mv in mvs if mv.group_id == group_id)

        if progress is not None:
            mvs = (mv for mv in mvs if mv.progress == progress)

        return tuple(mvs)
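
The assertions at the end of Example #18 double as a specification for this method; a minimal hedged sketch, assuming a populated ClusterView cv:

    # Every non-None argument narrows the result. Note that the shards
    # filter compares expanded shard sets for equality, not intersection.
    in_progress_drains = cv.search_maintenances(
        shard_target_state=ShardOperationalState.DRAINED,
        progress=MaintenanceProgress.IN_PROGRESS,
    )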