Ejemplo n.º 1
0
def __resolve_peering_config_variables(controller, config):
    """Substitute variables into the controller's configured peer values.

    Every entry of ``controller["peering"]["peers"]`` is passed through
    ``__resolve_string_vars`` together with the controller's ``variables``
    and the full config; the value it returns replaces the entry in place.

    Arguments:
        controller {dict} - controller config; its peers mapping is mutated
        config {dict} - full config, handed to the resolver

    Raises:
        CohesiveSDKException: if the resolver reported an error for any peer

    Returns:
        [dict] - the same ``config`` object that was passed in
    """
    peer_map = controller.get("peering", {}).get("peers", {})
    controller_vars = controller.get("variables", {})
    if not peer_map:
        return config

    failures = []
    # Only existing keys are re-assigned, so the dict size never changes
    # while it is being iterated.
    for peer_key in peer_map:
        resolved, failure = __resolve_string_vars(
            peer_map[peer_key], controller_vars, config
        )
        if failure:
            failures.append(failure)
        # The resolver's value is stored even when it reported an error.
        peer_map[peer_key] = resolved

    if not failures:
        return config
    raise CohesiveSDKException(
        "Failed to format peers: %s" % ",".join(failures)
    )
Ejemplo n.º 2
0
def __resolve_plugins_config_variables(controller, config):
    """Substitute variables into every value of the controller's plugin entries.

    Each value of each dict in ``controller["plugins"]`` is passed through
    ``__resolve_string_vars``. A value is replaced in place only when the
    resolver reported no error for it.

    Arguments:
        controller {dict} - controller config; its plugin dicts are mutated
        config {dict} - full config, handed to the resolver

    Raises:
        CohesiveSDKException: if the resolver reported an error for any value

    Returns:
        [dict] - the same ``config`` object that was passed in
    """
    plugin_entries = controller.get("plugins", [])
    controller_vars = controller.get("variables", {})
    if not plugin_entries:
        return config

    failures = []
    for entry in plugin_entries:
        for field in entry:
            resolved, failure = __resolve_string_vars(
                entry[field], controller_vars, config
            )
            if failure:
                # Leave the unresolved value untouched and keep collecting.
                failures.append(failure)
            else:
                entry[field] = resolved

    if not failures:
        return config
    raise CohesiveSDKException(
        "Failed to resolve controller plugin vars: %s" % ",".join(failures)
    )
Ejemplo n.º 3
0
def search_images(client, image_name):
    """Search plugin images by name (case-insensitive exact match).

    Arguments:
        client {VNS3Client}
        image_name {str}

    Raises:
        CohesiveSDKException - if the image listing is None

    Returns:
        ContainerImage - first image whose "image_name" matches, or None
    """
    resp_data = client.network_edge_plugins.get_container_system_images()
    images = resp_data.response.images
    if images is None:
        raise CohesiveSDKException("Container system is not running")

    # The listing may contain empty/None placeholders; ignore them.
    candidates = [image for image in images if image]
    if not candidates:
        return None

    wanted = image_name.lower()
    for image in candidates:
        # An entry without a name can never match; skipping it avoids
        # calling .lower() on None.
        current_name = image.get("image_name")
        if current_name and current_name.lower() == wanted:
            return image
    return None
Ejemplo n.º 4
0
def __resolve_route_config_variables(controller, config):
    """Substitute variables into every value of the controller's route kwargs.

    Each value of each dict in ``controller["routes"]`` is passed through
    ``__resolve_string_vars`` and written back in place.

    Arguments:
        controller {dict} - controller config; its route dicts are mutated
        config {dict} - full config, handed to the resolver

    Raises:
        CohesiveSDKException: if the resolver reported an error for any value

    Returns:
        [dict] - the same ``config`` object that was passed in
    """
    route_entries = controller.get("routes", [])
    if not route_entries:
        return config

    controller_vars = controller.get("variables", {})
    failures = []
    for entry in route_entries:
        for field in entry:
            resolved, failure = __resolve_string_vars(
                entry[field], controller_vars, config
            )
            if failure:
                failures.append(failure)
            # The resolver's value is written back even when it reported
            # an error; the exception below is raised afterwards.
            entry.update(**{field: resolved})

    if not failures:
        return config
    raise CohesiveSDKException(
        "Failed to format routes: %s" % ",".join(failures)
    )
Ejemplo n.º 5
0
def read_config_file(config_file):
    """Read a JSON config file and return the parsed content.

    Arguments:
        config_file {str} - path to the config file

    Raises:
        CohesiveSDKException - if the file content is not valid JSON
        OSError - propagated from open() if the file cannot be read

    Returns:
        the decoded JSON document
    """
    # Use a context manager so the handle is closed even if reading fails.
    with open(config_file) as config_fh:
        raw_config = config_fh.read()

    try:
        return json.loads(raw_config)
    except json.decoder.JSONDecodeError as err:
        # Keep the decoder error as the cause so the offending position
        # is still visible in the traceback.
        raise CohesiveSDKException(
            "Invalid config file %s. Must be valid json." % config_file
        ) from err
Ejemplo n.º 6
0
def wait_for_images_ready(client,
                          import_uuids=None,
                          interval=1.0,
                          timeout=120.0):
    """Poll until every image matching import_uuids has status "Ready".

    Returns True immediately, without calling the API, when import_uuids
    is empty or None.

    Arguments:
        client {VNS3Client}

    Keyword Arguments:
        import_uuids {List[str]} - import uuids of the images to wait for
        interval {float} - seconds to sleep between polls
        timeout {float} - seconds to keep polling before giving up

    Raises:
        CohesiveSDKException - if the first listing contains no images, or
            if the images are not ready before the timeout

    Returns:
        Bool
    """
    if not import_uuids:
        return True

    def fetch_images():
        resp_data = client.network_edge_plugins.get_container_system_images()
        return resp_data.response.images

    def targets_ready(images):
        # Only the requested imports matter; unrelated images that are
        # still importing must not hold up the caller.
        return all(
            image.get("status") == "Ready" for image in images
            if image and image.get("import_id") in import_uuids)

    start_time = time.time()
    images = fetch_images()
    if not images:
        raise CohesiveSDKException("No container images found.")

    if targets_ready(images):
        return True

    time.sleep(interval)
    while time.time() - start_time < timeout:
        images = fetch_images()
        # A None listing cannot be evaluated; treat it as "not ready yet"
        # and keep polling instead of failing on iteration.
        if images is not None and targets_ready(images):
            return True
        time.sleep(interval)

    raise CohesiveSDKException(
        "Timeout: Images failed to enter ready state [timeout=%s seconds, host=%s]"
        % (timeout, client.host_uri))
Ejemplo n.º 7
0
def __add_controller_states(config, infra_state, groups=None):
    """
    Copy per-controller infrastructure values into each controller's variables.

    Without groups, every value in infra_state must be a list or tuple and
    controller number i receives element i of each (when it exists). With
    groups, infra_state is keyed by group name; the controllers of a group
    (identified by their "group" variable) consume that group's lists in
    order, up to the size given for the group.

    Arguments:
        config {dict} - must contain "controllers"; controllers are mutated
        infra_state {dict}

    Keyword Arguments:
        groups {dict} - name: str -> size: int

    Raises:
        CohesiveSDKException: groups given but infra_state lacks one of them
        AssertionError: ungrouped infra_state value is not a list/tuple

    Returns:
        [dict] - the same config object
    """
    if not infra_state:
        return config

    controllers = config["controllers"]

    next_slot = None
    if groups:
        next_slot = dict.fromkeys(groups, 0)
        if any(name not in infra_state for name in groups):
            raise CohesiveSDKException(
                "If groups are provided. infra_state must be keyed by group. "
                "e.g. groups={aws: 3, azure: 2}, infra_state={aws: {...}, azure: {...}}"
            )

    for position, controller in enumerate(controllers):
        variables = controller.get("variables", {})
        if not groups:
            for name, per_controller in infra_state.items():
                assert type(per_controller) in (
                    tuple,
                    list,
                ), "Expected infra state vars to be lists, indexed by controller"
                if position < len(per_controller):
                    variables[name] = per_controller[position]
        else:
            group_name = variables["group"]
            if group_name not in infra_state:
                continue

            capacity = groups[group_name]
            slot = next_slot[group_name]
            # The group already handed out all of its entries.
            if slot >= capacity:
                continue

            for name, per_controller in infra_state[group_name].items():
                variables[name] = per_controller[slot]
                # Advance the group's cursor (same value for every key).
                next_slot[group_name] = slot + 1

        controller.update(variables=variables)
    return config
Ejemplo n.º 8
0
def __substitute_controller_variables(config):
    """Substitute variables and set defaults for every controller in config

    For each controller the global ``variables`` are merged with the
    controller's own (local values win), a default ``api_password`` and
    ``host`` are filled in when missing, formattable string variables are
    formatted against the merged variables, and the controller's routes,
    peering and plugins sections are resolved.

    Arguments:
        config {dict} - must contain "controllers"; controllers are mutated

    Raises:
        CohesiveSDKException: a variable references a name that is missing
        KeyError: a variable needed for a default (instance_id, instance_name,
            primary_private_ip, public_ip) is missing

    Returns:
        [dict] - the same config object
    """
    global_variables = config.get("variables", {})
    set_master_password = global_variables.get("set_master_password")
    master_password = global_variables.get("master_password")
    for controller in config["controllers"]:
        # add global variables controller state, overriding with local variables
        local_variables = dict(global_variables, **controller.get("variables", {}))

        # Password logic:
        #   - keep api_password when the merged variables already provide one
        #   - else, if a master password is given and the set_master_password
        #     flag is NOT true, use the master password
        #   - else use the cloud's default password
        if not local_variables.get("api_password"):
            if master_password and not set_master_password:
                local_variables["api_password"] = master_password
            elif local_variables.get("cloud", None) == "azure":
                # NOTE(review): a literal without conversion specifiers
                # raises TypeError when %-formatted with two values.
                # "%s-%s" (instance name, primary private IP) is assumed to
                # be the Azure default - confirm against the VNS3 docs.
                local_variables["api_password"] = "%s-%s" % (
                    local_variables["instance_name"],
                    local_variables["primary_private_ip"],
                )
            else:
                local_variables["api_password"] = local_variables["instance_id"]

        if not local_variables.get("host"):
            local_variables["host"] = local_variables["public_ip"]

        # Only existing keys are re-assigned, so iterating items() is safe.
        for key, value in local_variables.items():
            if util.is_formattable_string(value):
                try:
                    local_variables[key] = value.format(**local_variables)
                except KeyError:
                    raise CohesiveSDKException("Missing variable %s" % value)

        controller["variables"] = local_variables
        __resolve_route_config_variables(controller, config)
        __resolve_peering_config_variables(controller, config)
        __resolve_plugins_config_variables(controller, config)
    return config
Ejemplo n.º 9
0
def create_route_table(client: VNS3Client, routes, state=None):
    """Create routing policy

    Every value of every route dict is formatted against the state and
    written back into the route dict. Routes whose values all format
    cleanly are created (if not already existing); routes with a formatting
    error are skipped, and an exception is raised after all routes have
    been processed.

    Arguments:
        client {VNS3Client}
        routes {List[Route]} - [{
            "cidr": "str",
            "description": "str",
            "interface": "str",
            "gateway": "str",
            "tunnel": "int",
            "advertise": "bool",
            "metric": "int",
        }, ...]

    Keyword Arguments:
        state {dict} - State to format routes with. Falls back to
            client.state when omitted or empty.

    Raises:
        CohesiveSDKException - if any route value could not be formatted

    Returns:
        Tuple[List[str], List[str]] - success, errors (errors is always
        empty on return, since any error raises instead)
    """
    successes = []
    errors = []
    Logger.debug(
        "Setting controller route table.",
        host=client.host_uri,
        route_count=len(routes),
    )

    # None default instead of a shared mutable dict; both fall through to
    # client.state exactly as an empty dict did.
    _sub_vars = state or client.state
    for route_kwargs in routes:
        skip = False
        for key, value in route_kwargs.items():
            _value, err = util.format_string(value, _sub_vars)
            if err:
                errors.append("Route key %s not formattable." % key)
                skip = True
            else:
                route_kwargs.update(**{key: _value})

        if skip:
            continue

        client.routing.post_create_route_if_not_exists(route_kwargs)
        successes.append("Route created: route=%s" % str(route_kwargs))

    if errors:
        raise CohesiveSDKException(",".join(errors))

    return successes, errors
Ejemplo n.º 10
0
def get_image_id_from_import(client, import_id):
    """Look up the image ID that belongs to an import uuid

    Arguments:
        client {VNS3Client}
        import_id {str}

    Raises:
        CohesiveSDKException - if the listing for that uuid is empty or None

    Returns:
        str - "id" of the first image returned for the import uuid
    """
    listing = client.network_edge_plugins.get_container_system_images(
        uuid=import_id)
    matched = listing.response.images
    if matched:
        return matched[0].get("id")

    raise CohesiveSDKException("Couldnt find image for import id %s" %
                               import_id)
Ejemplo n.º 11
0
def _construct_peer_address_mapping(clients, address_type):
    """Build, for every client, a map of all *other* clients' peer addresses

    Arguments:
        clients {List[VNS3Client]}
        address_type {str} - one of primary_private_ip, secondary_private_ip, public_ip, public_dns

    Raises:
        CohesiveSDKException - if any other client's address is empty/None

    Returns:
        List of tuples where first element is the client and second is a map for peers

        List[Tuple[VNS3Client, Dict]] -- [
            (client,  {
                [peer_id: str]: [peer_address: str],
                ...
            })
        ]
    """
    # Populate the attribute on the clients first, when it can be fetched.
    if attribute_supported(address_type):
        fetch_state_attribute(clients, address_type)

    mapping = []
    for position, current in enumerate(clients):
        # Everyone except the client at this position, in list order.
        peer_addresses = {
            other.query_state(VNS3Attr.peer_id): other.query_state(address_type)
            for other_position, other in enumerate(clients)
            if other_position != position
        }
        if not all(peer_addresses.values()):
            raise CohesiveSDKException(
                "Could not determine %s for some clients" % address_type)
        mapping.append((current, peer_addresses))
    return mapping
Ejemplo n.º 12
0
def search_containers(client: VNS3Client, image_id=None):
    """List running plugin containers that were started from image_id

    Arguments:
        client {VNS3Client}
        image_id {str} - when empty or None, no API call is made

    Raises:
        CohesiveSDKException - if the container listing is None

    Returns:
        List - containers whose "image" equals image_id (possibly empty)
    """
    if not image_id:
        return []

    listing = (
        client.network_edge_plugins.get_container_system_running_containers())
    running = listing.response.containers
    if running is None:
        raise CohesiveSDKException("Container system is not running")

    return [
        container for container in running
        if image_id == container.get("image")
    ]
Ejemplo n.º 13
0
def __add_config_from_env(config, env_config):
    """Merge configuration details taken from the environment into config

    "variables" and "plugin_images" from env_config are merged over the
    ones in config (empty values filtered out first). Keys starting with
    "controllers" must hold lists; element i is stored on controller i
    under the key name with its first "_"-separated segment removed. Other
    str/int values are copied to the top level of the result.

    Arguments:
        config {dict} - must contain "controllers"; that list and its
            controller dicts are mutated in place
        env_config {dict} - "variables" and "plugin_images" are popped

    Raises:
        CohesiveSDKException: value is neither a controllers list nor str/int
        AssertionError: a "controllers..." value is not a list

    Returns:
        [dict] - new top-level dict sharing the controllers list with config
    """
    NONE_VALS = ("", None)
    topology_vars = env_config.pop("variables", {})
    topology_vars_non_null = _filter_dict_none_vals(topology_vars)
    topology_vars_plugin_images = env_config.pop("plugin_images", {})
    topology_vars_plugin_images_non_null = _filter_dict_none_vals(
        topology_vars_plugin_images
    )

    updated_config = dict(
        config,
        **{
            "variables": dict(config.get("variables", {}), **topology_vars_non_null),
            "plugin_images": dict(
                config.get("plugin_images", {}), **topology_vars_plugin_images_non_null
            ),
        }
    )

    controllers = config["controllers"]
    for key, config_value in env_config.items():
        if config_value in NONE_VALS:
            continue

        if key.startswith("controllers"):
            assert (
                type(config_value) is list
            ), "Controller state vars should be passed as lists"
            var_name = "_".join(key.split("_")[1:]).strip("_")
            for i, controller_var_value in enumerate(config_value):
                if controller_var_value in NONE_VALS:
                    continue

                # Index i must exist before it is read. Skipped (empty)
                # entries can leave a gap of more than one, so pad until
                # the list is long enough.
                while len(controllers) <= i:
                    controllers.append({})

                controller_config = controllers[i]
                controller_vars = controller_config.get("variables", {})
                controller_vars.update(**{var_name: controller_var_value})
                controller_config.update(variables=controller_vars)

        elif type(config_value) in (str, int):
            updated_config.update(**{key: config_value})
        else:
            raise CohesiveSDKException(
                "Unknown environment variable value key=%s, value=%s"
                % (key, config_value)
            )

    updated_config.update(controllers=controllers)
    return updated_config
Ejemplo n.º 14
0
def setup_controller(
    client: VNS3Client,
    topology_name: str,
    license_file: str,
    license_parameters: Dict,
    keyset_parameters: Dict,
    controller_name: str = None,
    peering_id: int = 1,
    reboot_timeout=120,
    keyset_timeout=120,
):
    """setup_controller Set the topology name, controller license, keyset and peering ID if provided

    Each step is only performed when the controller does not already report
    it as done, so the function can be re-run against a partly configured
    controller.

    Arguments:
        client {VNS3Client}
        topology_name {str}
        license_file {str} -- path to the license file; only read when the
            controller reports it is not licensed
        license_parameters {Dict} -- keyword arguments passed to
            put_set_license_parameters when the license must be accepted
        keyset_parameters {Dict} -- UpdateKeysetRequest {
            'source': 'str',
            'token': 'str',
            'topology_name': 'str',
            'sealed_network': 'bool'
        }

    Keyword Arguments:
        controller_name {str}
        peering_id {int} -- ID for this controller in peering mesh
        reboot_timeout -- timeout passed to wait_for_api after the license
            parameters are set
        keyset_timeout -- timeout passed to wait_for_keyset

    Raises:
        CohesiveSDKException -- the controller is unlicensed and
            license_file is not an existing file
        ApiException -- re-raised from get_license unless its message is
            "Must be licensed first."

    Returns:
        VNS3Client -- the client that was passed in
    """
    current_config = client.config.get_config().response
    # Logged unconditionally, even when no name change is needed.
    Logger.info("Setting topology name", name=topology_name)
    # Only send the names that differ from what the controller reports.
    config_update = {}
    if topology_name and current_config.topology_name != topology_name:
        config_update.update({"topology_name": topology_name})
    if controller_name and current_config.controller_name != controller_name:
        config_update.update({"controller_name": controller_name})
    if config_update:
        client.config.put_config(**config_update)

    if not current_config.licensed:
        if not os.path.isfile(license_file):
            raise CohesiveSDKException("License file does not exist")

        with open(license_file) as f:
            license_file_data = f.read().strip()
        Logger.info("Uploading license file", path=license_file)
        client.licensing.upload_license(license_file_data)

    # Accept the license when there is none yet or it is not finalized.
    accept_license = False
    try:
        current_license = client.licensing.get_license().response
        accept_license = not current_license or not current_license.finalized
    except ApiException as e:
        # This specific API error is treated as "license still to be accepted".
        if e.get_error_message() == "Must be licensed first.":
            accept_license = True
        else:
            raise e

    if accept_license:
        Logger.info("Accepting license", parameters=license_parameters)
        client.licensing.put_set_license_parameters(**license_parameters)
        Logger.info("Waiting for server reboot.")
        client.sys_admin.wait_for_api(timeout=reboot_timeout, wait_for_reboot=True)

    current_keyset = api_operations.retry_call(
        client.config.get_keyset, max_attempts=20
    ).response
    # Generate a keyset only when none exists and none is being generated;
    # if one is already in progress, just wait for it.
    if not current_keyset.keyset_present and not current_keyset.in_progress:
        Logger.info("Generating keyset", parameters=keyset_parameters)
        api_operations.retry_call(client.config.put_keyset, kwargs=keyset_parameters)
        Logger.info("Waiting for keyset ready")
        client.config.wait_for_keyset(timeout=keyset_timeout)
    elif current_keyset.in_progress:
        client.config.wait_for_keyset(timeout=keyset_timeout)

    # Set the peering id only when the controller has none yet.
    current_peering_status = client.peering.get_peering_status().response
    if not current_peering_status.id and peering_id:
        Logger.info("Setting peering id", id=peering_id)
        client.peering.put_self_peering_id(**{"id": peering_id})
    return client
Ejemplo n.º 15
0
def peer_mesh(
    clients,
    peer_address_map=None,
    address_type=VNS3Attr.primary_private_ip,
    delay_configure=False,
    mtu=None,
):
    """peer_mesh Create a peering mesh by adding each client as peer for other clients.

    Peer ids are first fetched onto the clients. The peer addresses come
    either from the explicit peer_address_map or, when that is None, from
    each other client's address_type attribute. One bound call per client
    creates all of its peers; the calls are run as a parallel bulk
    operation.

    Arguments:
        clients {List[VNS3Client]}

    Keyword Arguments:
        peer_address_map {Dict} - Optional map for peering addresses {
            [from_peer_id: str]: {
                [to_peer_id_1: str]: [peer_address_1: str],
                [to_peer_id_2: str]: [peer_address_2: str],
                ...
            }
        }
        address_type {str} - which address to use. Options: primary_private_ip, secondary_private_ip, public_ip or public_dns
        delay_configure {bool} -- when true, peers are created with force=False (default: False)
        mtu {int} -- when set, sent as overlay_mtu for each peer (default: {None})

    Raises:
        CohesiveSDKException - if peer ids could not be fetched for all clients

    Returns:
        the result of the bulk API call
    """
    # Every client needs its peer id in state before the mesh is built.
    peer_id_fetch = fetch_state_attribute(clients, VNS3Attr.peer_id)
    if api_ops.bulk_operation_failed(peer_id_fetch):
        errors_str = api_ops.stringify_bulk_result_exception(peer_id_fetch)
        Logger.error("Failed to fetch peering Ids for all clients",
                     errors=errors_str)
        raise CohesiveSDKException(
            "Failed to fetch peering Ids for all clients: %s" % errors_str)

    # Work out (client, {peer_id: address}) pairs.
    if peer_address_map is None:
        Logger.debug("Constructing peering mesh")
        client_peer_pairs = _construct_peer_address_mapping(
            clients, address_type)
    else:
        Logger.debug("Using address map passed for peering mesh.")
        clients_by_peer_id = {}
        for candidate in clients:
            clients_by_peer_id[candidate.query_state(VNS3Attr.peer_id)] = candidate
        client_peer_pairs = []
        for from_peer_id, to_peers_map in peer_address_map.items():
            client_peer_pairs.append(
                (clients_by_peer_id[from_peer_id], to_peers_map))

    # Options applied to every peer creation request.
    common_peer_kwargs = {}
    if delay_configure:
        common_peer_kwargs["force"] = False
    if mtu:
        common_peer_kwargs["overlay_mtu"] = mtu

    def create_all_peers_for_client(client, post_peer_kwargs):
        responses = []
        for peering_request in post_peer_kwargs:
            merged_request = {**peering_request, **common_peer_kwargs}
            responses.append(client.peering.post_create_peer(**merged_request))
        return responses

    # One bound call per client, each creating all of that client's peers.
    run_peering_funcs = [
        bind(
            create_all_peers_for_client,
            vns3_client,
            [{
                "id": peer_id,
                "name": peer_address
            } for peer_id, peer_address in peer_mapping.items()],
        ) for vns3_client, peer_mapping in client_peer_pairs
    ]

    Logger.debug("Creating %d-way peering mesh." % len(clients))
    return api_ops.__bulk_call_api(run_peering_funcs, parallelize=True)
def configure_multicloud_bridge_client(**bridge_kwargs):
    """Configure one client as an end of a multicloud IPsec bridge

    Sets up the controller, creates the local gateway route for the
    client's own network and then creates the IPsec tunnel endpoint.
    Failures are *returned* as exception instances rather than raised.

    Arguments:
        target_client {VNS3Client} - client to be configured
        target_topology_name {str}
        peer_endpoint {str} - controller endpoint on otherside of bridge
        endpoint_name {str} - name for the IPsec endpoint
        tunnel_vti {str} - CIDR to be used for VTI interface

        license_file {str} - full path to license file
        keyset_token {str} - passed as the keyset "token" parameter
        target_cidr {str} - cidr for this clients network
        peer_cidr {str} - cidr accessible on other side of bridge
        tunnel_psk {str} -- Preshared key for IPsec tunnel

    Returns:
        The result of ipsec.create_tunnel_endpoint, OR a
        CohesiveSDKException / ApiException instance on failure
    """
    required_names = (
        "target_client",
        "target_topology_name",
        "peer_endpoint",
        "endpoint_name",
        "tunnel_vti",
        "license_file",
        "keyset_token",
        "target_cidr",
        "peer_cidr",
        "tunnel_psk",
    )

    absent = [name for name in required_names if name not in bridge_kwargs]
    if absent:
        # Returned, not raised, like every other failure of this function.
        return CohesiveSDKException("Missing args for bridge %s" % absent)

    try:
        target = bridge_kwargs["target_client"]
        topology = bridge_kwargs["target_topology_name"]
        print("Setup for %s..." % topology)
        config.setup_controller(
            target,
            topology,
            bridge_kwargs["license_file"],
            license_parameters={"default": True},
            keyset_parameters={"token": bridge_kwargs["keyset_token"]},
            reboot_timeout=240,
            keyset_timeout=240,
        )

        local_cidr = bridge_kwargs["target_cidr"]
        print("Creating local gateway routes for %s" % local_cidr)
        routing.create_local_gateway_route(
            target, local_cidr, should_raise=False)

        tunnel_name = bridge_kwargs["endpoint_name"]
        print("Creating tunnel: %s" % tunnel_name)
        tunnel_args = (
            bridge_kwargs["tunnel_psk"],
            bridge_kwargs["peer_endpoint"],
            bridge_kwargs["peer_cidr"],
            bridge_kwargs["tunnel_vti"],
        )
        return ipsec.create_tunnel_endpoint(target, tunnel_name, *tunnel_args)
    except (ApiException, CohesiveSDKException) as failure:
        return failure
Ejemplo n.º 17
0
def assert_rule_policy(client: VNS3Client, rules, should_fix=False):
    """Assert the controller's firewall matches the expected rules exactly

    The expected rule list is built from ``rules`` and the client state and
    compared with the rules currently on the controller. When they differ
    and should_fix is true, the controller is edited rule by rule
    (insert/delete by position) until it matches.

    Arguments:
        client {VNS3Client}
        rules {List[dict]}

    Keyword Arguments:
        should_fix {bool} - if false, raise Error, else, update firewall

    Raises:
        CohesiveSDKException - raised if invalid firewall rules provided
        AssertionError - raised if should_fix=False and provided rules dont match VNS3

    Returns:
        List[str] - ordered list of firewall rules: the current list when
        it already matches, otherwise the list re-read after the edits
    """
    current_firewall = __firewall_resp_to_list(
        client.firewall.get_firewall_rules())
    new_firewall, errors = __construct_proposed_firewall_list(
        rules, state=client.state)
    if errors:
        raise CohesiveSDKException(
            "Invalid firewall rules provided. Errors=%s" % (errors))

    if current_firewall == new_firewall:
        Logger.info("Current firewall is correct. No-op.",
                    host=client.host_uri)
        return current_firewall

    Logger.info(
        "Firewall configuration drift. Expected: %s != %s." %
        (new_firewall, current_firewall),
        host=client.host_uri,
    )

    if not should_fix:
        raise AssertionError(
            "Firewalls did not match for VNS3 @ %s. Current firewall %s != %s."
            % (client.host_uri, current_firewall, new_firewall))

    # operations: insert, delete
    OP_INS = "insert"
    OP_DEL = "delete"
    firewall_edits = []
    # current_firewall is kept in step with every remote edit so that the
    # positions used below always refer to the controller's current state.
    for i, rule in enumerate(new_firewall):
        if len(current_firewall) <= i:
            operation = OP_INS
        elif current_firewall[i] == rule:
            continue
        else:
            # current firewall rule is incorrect.
            # now, minimize operations to get correct
            # if can insert OR delete, prefer delete
            # ie. if next rule is the correct rule, del this rule
            operation = (OP_DEL if len(current_firewall) > i + 1
                         and current_firewall[i + 1] == rule else OP_INS)

        firewall_edits.append("%s:%s" % (operation, i))
        if operation == OP_INS:
            client.firewall.post_create_firewall_rule(**{
                "rule": rule,
                "position": i
            })
            current_firewall.insert(i, rule)
        else:  # operation == OP_DEL:
            client.firewall.delete_firewall_rule_by_position(i)
            del current_firewall[i]

    # Inserts push mismatching rules towards the tail, and the controller
    # may simply have had more rules than expected. Remove everything past
    # the expected length, last position first so the remaining positions
    # stay valid.
    for position in range(len(current_firewall) - 1, len(new_firewall) - 1, -1):
        firewall_edits.append("%s:%s" % (OP_DEL, position))
        client.firewall.delete_firewall_rule_by_position(position)
        del current_firewall[position]

    Logger.debug(
        "%s network operations required to fix firewall: %s" %
        (len(firewall_edits), firewall_edits),
        host=client.host_uri,
    )

    return __firewall_resp_to_list(client.firewall.get_firewall_rules())
Ejemplo n.º 18
0
def fetch_keyset_from_source(  # noqa: C901
    client, source, token, wait_timeout=180.0, allow_exists=False
):  # noqa
    """fetch_keyset_from_source Put keyset by providing source controller to download keyset. This
    contains logic that handles whether or not fetching from the source fails, typically due
    to a firewall or routing issue in the underlay network (e.g. security groups and route tables).

    Pseudo-logic:
        PUT new keyset request to fetch from remote controller
        if the PUT is rejected: return the existing keyset when allow_exists is set and the
            error says the keyset already exists, otherwise re-raise
        if PUT succeeds, poll by repeating the same PUT:
            if a new successful put returns with a different start time: that indicates
                failure to download from source. raise 400
            if the connection fails: that indicates controller is rebooting, wait for the
                API and then for the keyset
            if keyset already exists: return keyset details (waiting for the reboot first
                when they cannot be read yet)
            if keyset setup is in progress: sleep and retry

    Arguments:
        client - API client; client.config and client.sys_admin are used
        source {str} - host controller to fetch keyset from
        token {str} - secret token used when generating keyset
        wait_timeout {float} - timeout for waiting for keyset and while polling for download failure (default: 180 seconds)
        allow_exists {bool} - If true and keyset already exists, DONT throw exception

    Raises:
        ApiException - the initial PUT failed, the controller was unreachable, the PUT
            returned no data, the download from source failed, or polling hit an
            unexpected API error
        CohesiveSDKException - polling exceeded wait_timeout

    Returns:
        KeysetDetail
    """
    # Seconds to sleep between polling attempts.
    sleep_time = 2.0
    failure_error_str = (
        "Failed to fetch keyset for source. This typically due to a misconfigured "
        "firewall or routing issue between source and client controllers."
    )

    try:
        put_response = client.config.put_keyset(**{"source": source, "token": token})
    except ApiException as e:
        if allow_exists and ("keyset already exists" in e.get_error_message().lower()):
            Logger.info("Keyset already exists.", host=client.host_uri)
            return client.config.try_get_keyset()

        Logger.info(
            "Failed to fetch keyset: %s" % e.get_error_message(),
            host=client.host_uri,
        )
        raise e
    except UrlLib3ConnExceptions:
        # Connection-level failure is reported to the caller as a 503.
        raise ApiException(
            status=HTTPStatus.SERVICE_UNAVAILABLE,
            reason="Controller unavailable. It is likely rebooting. Try client.sys_admin.wait_for_api().",
        )

    # The PUT returned no data: tell "keyset already there" apart from an
    # unexplained empty response.
    if not put_response.response:
        keyset_data = client.config.get_keyset()
        if keyset_data.response and keyset_data.response.keyset_present:
            raise ApiException(status=400, reason="Keyset already exists.")
        raise ApiException(status=500, reason="Put keyset returned None.")

    # Start marker of the download just requested; compared against the
    # marker returned by the repeated PUTs below.
    start_time = put_response.response.started_at_i
    Logger.info(message="Keyset downloading from source.", start_time=start_time)
    polling_start = time.time()
    while time.time() - polling_start <= wait_timeout:
        try:
            # Repeat the same PUT; how it responds reveals the download state.
            duplicate_call_resp = client.config.put_keyset(
                **{"source": source, "token": token}
            )
        except UrlLib3ConnExceptions:
            Logger.info(
                "API call timeout. Controller is likely rebooting. Waiting for keyset.",
                wait_timeout=wait_timeout,
                source=source,
            )
            client.sys_admin.wait_for_api(timeout=wait_timeout, wait_for_reboot=True)
            return client.config.wait_for_keyset(timeout=wait_timeout)
        except ApiException as e:
            duplicate_call_error = e.get_error_message()

            if duplicate_call_error == "Keyset already exists.":
                keyset_data = client.config.try_get_keyset()
                if not keyset_data:
                    Logger.info(
                        "Keyset exists. Waiting for reboot.",
                        wait_timeout=wait_timeout,
                        source=source,
                    )
                    client.sys_admin.wait_for_api(
                        timeout=wait_timeout, wait_for_reboot=True
                    )
                    # NOTE(review): no timeout is passed here, unlike the
                    # other wait_for_keyset call above - confirm intended.
                    return client.config.wait_for_keyset()
                return keyset_data

            if duplicate_call_error == "Keyset setup in progress.":
                # this means download is in progress, but might fail. Wait and retry
                time.sleep(sleep_time)
                continue

            # Unexpected ApiException
            raise e

        # If there is a new start time for keyset generation, that indicates a failure
        new_start_resp = duplicate_call_resp.response
        new_start = new_start_resp.started_at_i if new_start_resp else None
        if (new_start and start_time) and (new_start != start_time):
            Logger.error(failure_error_str, source=source)
            raise ApiException(status=HTTPStatus.BAD_REQUEST, reason=failure_error_str)

        time.sleep(sleep_time)
    raise CohesiveSDKException("Timeout while waiting for keyset.")