Ejemplo n.º 1
0
def ztp_device_cert(task, job_id: str, new_hostname: str,
                    management_ip: str) -> str:
    """
    Nornir task to generate a device certificate and install it on the device.

    Args:
        task: nornir task, sent by nornir when doing .run()
        job_id: job ID, passed to set_thread_data() and to the copy subtask
        new_hostname: hostname the certificate is generated for
        management_ip: IPv4 address (as a string) passed to the generator

    Returns:
        String describing the result. For platforms other than "eos" a
        "not supported" message is returned instead of raising.

    Raises:
        Exception: if the certificate could not be generated, or whatever
            the certificate copy subtask raised.
    """
    set_thread_data(job_id)
    logger = get_logger()

    try:
        ipv4: IPv4Address = IPv4Address(management_ip)
        generate_device_cert(new_hostname, ipv4_address=ipv4)
    except Exception as e:
        # Chain the original error so its traceback is not lost
        raise Exception(
            "Could not generate certificate for device {}: {}".format(
                new_hostname, e)) from e

    if task.host.platform != "eos":
        return "Install device certificate not supported on platform: {}".format(
            task.host.platform)

    try:
        # TODO: subtaskerror?
        task.run(task=arista_copy_cert, job_id=job_id)
    except Exception as e:
        logger.exception('Exception while copying certificates: {}'.format(
            str(e)))
        raise
    return "Device certificate installed for {}".format(new_hostname)
Ejemplo n.º 2
0
def push_static_config(task, config: str, dry_run: bool = True,
                       job_id: Optional[str] = None,
                       scheduled_by: Optional[str] = None):
    """
    Nornir task that replaces the device configuration with a static config.

    Args:
        task: nornir task, sent by nornir when doing .run()
        config: static config to apply
        dry_run: Don't commit config to device, just do compare/diff
        job_id: job ID, passed to set_thread_data()
        scheduled_by: username that triggered job (not used by this task)

    Returns:
        None
    """
    set_thread_data(job_id)
    log = get_logger()
    log.debug("Push static config to device: {}".format(task.host.name))

    # Arguments forwarded to the napalm_configure subtask
    configure_args = {
        "name": "Push static config",
        "replace": True,
        "configuration": config,
        "dry_run": dry_run,
    }
    task.run(task=napalm_configure, **configure_args)
Ejemplo n.º 3
0
def arista_pre_flight_check(task, job_id: Optional[str] = None) -> str:
    """
    NorNir task that checks free flash space before an upgrade and removes
    old firmware images when space is low.

    Args:
        task: NorNir task
        job_id: job ID, used for thread data and for the abort check

    Returns:
        String, describing the result

    Raises:
        Exception: if the free-space command did not give a usable result
    """
    set_thread_data(job_id)
    log = get_logger()
    with sqla_session() as session:
        if Job.check_job_abort_status(session, job_id):
            return "Pre-flight aborted"

    flash_diskspace = 'bash timeout 5 df /mnt/flash | awk \'{print $4}\''
    flash_cleanup = 'bash timeout 30 ls -t /mnt/flash/*.swi | tail -n +2 | grep -v `cut -d"/" -f2 /mnt/flash/boot-config` | xargs rm -f'

    # Ask the device how much space is left on flash
    space_reply = task.run(napalm_cli, commands=[flash_diskspace])
    usable = (isinstance(space_reply, MultiResult)
              and len(space_reply.result.keys()) == 1)
    if not usable:
        raise Exception('Could not check free space')

    # The value is read from the second line of the command output
    command_output = next(iter(space_reply.result.values()))
    free_bytes = command_output.split('\n')[1]

    if int(free_bytes) >= 2500000:
        log.info('Enough free space ({}b), no cleanup'.format(free_bytes))
    else:
        # Remove old firmware images to make room
        log.info('Cleaning up old firmware images on {}'.format(
            task.host.name))
        task.run(napalm_cli, commands=[flash_cleanup])

    return "Pre-flight check done."
Ejemplo n.º 4
0
def sync_check_hash(task, force=False, job_id=None):
    """
    Compare the stored configuration hash of a device with the SHA-256
    hash of its current running configuration.

    Nothing is checked when force is True or when no hash is stored for
    the device.

    Args:
        task: Nornir task
        force: Ignore device hash
        job_id: job ID, passed to set_thread_data()

    Raises:
        Exception: if the running configuration could not be fetched, or
            if its hash differs from the stored hash.
    """
    set_thread_data(job_id)
    logger = get_logger()
    if force is True:
        return
    with sqla_session() as session:
        stored_hash = Device.get_config_hash(session, task.host.name)
    if stored_hash is None:
        return

    task.host.open_connection("napalm", configuration=task.nornir.config)
    try:
        res = task.run(task=napalm_get, getters=["config"])
    finally:
        # Always release the connection, even if the getter failed
        task.host.close_connection("napalm")

    # Check for a missing config before encoding; calling .encode() on None
    # would raise AttributeError and make the check below unreachable.
    running_config = dict(res.result)['config']['running']
    if running_config is None:
        raise Exception('Failed to get running configuration')
    running_hash = sha256(running_config.encode()).hexdigest()
    if stored_hash != running_hash:
        raise Exception('Device {} configuration is altered outside of CNaaS!'.format(task.host.name))
Ejemplo n.º 5
0
def renew_cert_task(task, job_id: str) -> str:
    """
    Nornir task to generate a new device certificate and copy it to the device.

    Args:
        task: nornir task, sent by nornir when doing .run()
        job_id: job ID, passed to set_thread_data() and to the copy subtask

    Returns:
        String describing the result

    Raises:
        Exception: if the device is not found in the database, has no
            management_ip, or the certificate could not be generated.
        ValueError: if the device platform is not "eos".
    """
    set_thread_data(job_id)
    logger = get_logger()

    with sqla_session() as session:
        dev: Device = session.query(Device). \
            filter(Device.hostname == task.host.name).one_or_none()
        # one_or_none() returns None when there is no matching row
        if dev is None:
            raise Exception("Device {} not found in database".format(
                task.host.name))
        ip = dev.management_ip
        if not ip:
            raise Exception("Device {} has no management_ip".format(
                task.host.name))

    try:
        generate_device_cert(task.host.name, ipv4_address=ip)
    except Exception as e:
        # Chain the original error so its traceback is not lost
        raise Exception(
            "Could not generate certificate for device {}: {}".format(
                task.host.name, e)) from e

    if task.host.platform != "eos":
        raise ValueError("Unsupported platform: {}".format(task.host.platform))

    try:
        task.run(task=arista_copy_cert, job_id=job_id)
    except Exception as e:
        logger.exception('Exception while copying certificates: {}'.format(
            str(e)))
        raise

    return "Certificate renew success for device {}".format(task.host.name)
Ejemplo n.º 6
0
def push_base_management_access(task, device_variables, job_id):
    """
    Nornir task that renders the ACCESS entrypoint template for the host
    platform and pushes the result as the device's base management config.

    Args:
        task: nornir task, sent by nornir when doing .run()
        device_variables: dict of template variables, merged over settings
        job_id: job ID, passed to set_thread_data()

    Raises:
        RepoStructureException: if mapping.yml is missing for the platform
        InitError: if the config push failed but the device still answers
            a facts query afterwards
    """
    set_thread_data(job_id)
    log = get_logger()
    log.debug("Push basetemplate for host: {}".format(task.host.name))

    with open('/etc/cnaas-nms/repository.yml', 'r') as db_file:
        repo_config = yaml.safe_load(db_file)
    local_repo_path = repo_config['templates_local']

    mapfile = os.path.join(local_repo_path, task.host.platform, 'mapping.yml')
    if not os.path.isfile(mapfile):
        raise RepoStructureException(
            "File {} not found in template repo".format(mapfile))
    with open(mapfile, 'r') as f:
        mapping = yaml.safe_load(f)
    template = mapping['ACCESS']['entrypoint']

    settings, settings_origin = get_settings(task.host.name, DeviceType.ACCESS)

    # Environment variables prefixed with TEMPLATE_SECRET_ are exposed to the
    # template, so secrets can be kept outside of the templates repository.
    template_secrets = {
        name: value
        for name, value in os.environ.items()
        if name.startswith('TEMPLATE_SECRET_')
    }

    # Later dicts win: device_variables override settings (including the
    # interface list), and secrets override both.
    template_vars = {**settings, **device_variables, **template_secrets}

    rendered = task.run(task=text.template_file,
                        name="Generate initial device config",
                        template=template,
                        path=f"{local_repo_path}/{task.host.platform}",
                        **template_vars)

    # TODO: Handle template not found, variables not defined

    task.host["config"] = rendered.result
    # Use an extra low timeout, since connectivity is expected to be lost
    # after the management IP changes
    low_timeout = ConnectionOptions(extras={"timeout": 30})
    task.host.connection_options["napalm"] = low_timeout

    try:
        task.run(task=networking.napalm_configure,
                 name="Push base management config",
                 replace=True,
                 configuration=task.host["config"],
                 dry_run=False)
    except Exception:
        # If the device still answers after the failed push, the new config
        # was not committed
        task.run(task=networking.napalm_get, getters=["facts"])
        still_reachable = not task.results[-1].failed
        if still_reachable:
            raise InitError(
                "Device {} did not commit new base management config".format(
                    task.host.name))
Ejemplo n.º 7
0
def arista_post_flight_check(task,
                             post_waittime: int,
                             job_id: Optional[str] = None) -> str:
    """
    NorNir task that waits, reads device facts, and stores the new OS
    version for the device after an upgrade.

    Args:
        task: NorNir task
        post_waittime: Time to wait before trying to gather facts
        job_id: job ID, used for thread data and for the abort check

    Returns:
        String, describing the result. Failures are reported in the
        returned string rather than raised.

    """
    set_thread_data(job_id)
    log = get_logger()
    time.sleep(int(post_waittime))
    log.info(
        'Post-flight check wait ({}s) complete, starting check for {}'.format(
            post_waittime, task.host.name))
    with sqla_session() as session:
        if Job.check_job_abort_status(session, job_id):
            return "Post-flight aborted"

    try:
        facts_reply = task.run(napalm_get, getters=["facts"])
        os_version = facts_reply[0].result['facts']['os_version']

        with sqla_session() as session:
            device_query = session.query(Device).filter(
                Device.hostname == task.host.name)
            dev: Device = device_query.one()
            prev_os_version = dev.os_version
            dev.os_version = os_version
            if prev_os_version == os_version:
                # Same version as before means the upgrade did not activate
                log.error(
                    "OS version did not change, activation failed on {}".
                    format(task.host.name))
                raise Exception("OS version did not change, activation failed")
            # Version changed: clear the config hash and mark unsynchronized
            dev.confhash = None
            dev.synchronized = False
    except Exception as e:
        reason = str(e)
        log.exception("Could not update OS version on device {}: {}".format(
            task.host.name, reason))
        return 'Post-flight failed, could not update OS version: {}'.format(
            reason)

    return "Post-flight, OS version updated from {} to {}.".format(
        prev_os_version, os_version)
Ejemplo n.º 8
0
def arista_device_reboot(task, job_id: Optional[str] = None) -> str:
    """
    NorNir task that saves the configuration and reloads a single device.

    Args:
        task: NorNir task.
        job_id: job ID, used for thread data and for the abort check

    Returns:
        String, describing the result

    """
    set_thread_data(job_id)
    log = get_logger()
    with sqla_session() as session:
        if Job.check_job_abort_status(session, job_id):
            return "Reboot aborted"

    # Commands are sent in this order, each with its own netmiko arguments
    reboot_steps = (
        {'command_string': 'enable', 'expect_string': '.*#'},
        {'command_string': 'write', 'expect_string': '.*#'},
        {'command_string': 'reload force', 'max_loops': 2,
         'expect_string': '.*'},
    )

    try:
        for step_args in reboot_steps:
            task.run(netmiko_send_command, **step_args)
    except Exception as e:
        log.exception('Failed to reboot switch {}: {}'.format(
            task.host.name, str(e)))
        raise e

    return "Device reboot done."
Ejemplo n.º 9
0
 def wrapper(job_id: int, scheduled_by: str, kwargs=None):
     """
     Run the wrapped function as a tracked job.

     Marks the Job row as started, calls func(**kwargs) with job_id added
     to kwargs, and records success or the exception on the Job row. For
     the functions listed in progress_functions a progress thread is run
     alongside the job.

     Args:
         job_id: ID of an existing Job row, must be a non-zero int
         scheduled_by: not used by this wrapper
         kwargs: keyword arguments for the wrapped function; 'job_id' is
             added to this dict

     Returns:
         The result of the wrapped function after insert_job_id()

     Raises:
         ValueError: if job_id is invalid or the job is not in the database
         Exception: any exception from the wrapped function is re-raised
             after being recorded on the job
     """
     # A shared `{}` default would be mutated below and leak between calls
     if kwargs is None:
         kwargs = {}
     if not job_id or type(job_id) is not int:
         errmsg = "Missing job_id when starting job for {}".format(
             func.__name__)
         logger.error(errmsg)
         raise ValueError(errmsg)
     progress_functions = ('sync_devices', 'device_upgrade')
     track_progress = func.__name__ in progress_functions
     stop_event = None
     device_thread = None
     with sqla_session() as session:
         job = session.query(Job).filter(Job.id == job_id).one_or_none()
         if not job:
             errmsg = "Could not find job_id {} in database".format(job_id)
             logger.error(errmsg)
             raise ValueError(errmsg)
         kwargs['job_id'] = job_id
         # Don't send new function name unless it was set to "wrapper"
         function_name = None
         if job.function_name == "wrapper":
             function_name = func.__name__
         job.start_job(function_name=function_name)
         if track_progress:
             stop_event = threading.Event()
             device_thread = threading.Thread(
                 target=update_device_progress_thread,
                 args=(stop_event, job_id))
             device_thread.start()
     try:
         set_thread_data(job_id)
         # kwargs is contained in an item called kwargs because of the apscheduler.add_job call
         res = func(**kwargs)
         # job_id was validated as a non-zero int above
         res = insert_job_id(res, job_id)
         del thread_data.job_id
     except Exception as e:
         tb = traceback.format_exc()
         logger.debug("Exception traceback in job_wrapper: {}".format(tb))
         # Signal the progress thread first, so it is stopped even if the
         # job lookup below fails
         if stop_event is not None:
             stop_event.set()
         with sqla_session() as session:
             job = session.query(Job).filter(Job.id == job_id).one_or_none()
             if not job:
                 errmsg = "Could not find job_id {} in database".format(
                     job_id)
                 logger.error(errmsg)
                 raise ValueError(errmsg)
             job.finish_exception(e, tb)
             session.commit()
         raise
     else:
         if stop_event is not None:
             stop_event.set()
             device_thread.join()
         with sqla_session() as session:
             job = session.query(Job).filter(Job.id == job_id).one_or_none()
             if not job:
                 errmsg = "Could not find job_id {} in database".format(
                     job_id)
                 logger.error(errmsg)
                 raise ValueError(errmsg)
             job.finish_success(res, find_nextjob(res))
             session.commit()
         return res
Ejemplo n.º 10
0
def push_base_management(task, device_variables: dict, devtype: DeviceType,
                         job_id):
    """
    Nornir task that tries to install a device certificate, renders the
    entrypoint template for the given device type and pushes the result as
    the device's base management config.

    Args:
        task: nornir task, sent by nornir when doing .run()
        device_variables: dict of template variables; 'mgmt_ip' is read for
            the certificate and 'host' for log messages
        devtype: device type, its name selects the entry in mapping.yml
        job_id: job ID, passed to set_thread_data() and to the cert subtask

    Raises:
        RepoStructureException: if mapping.yml is missing for the platform
        Exception: if the certificate subtask result is failed and
            device_cert_required() is true
        InitError: if the config push failed but the device still answers
            a facts query afterwards
    """
    set_thread_data(job_id)
    log = get_logger()
    log.debug("Push basetemplate for host: {}".format(task.host.name))

    with open('/etc/cnaas-nms/repository.yml', 'r') as db_file:
        repo_config = yaml.safe_load(db_file)
    local_repo_path = repo_config['templates_local']

    mapfile = os.path.join(local_repo_path, task.host.platform, 'mapping.yml')
    if not os.path.isfile(mapfile):
        raise RepoStructureException(
            "File {} not found in template repo".format(mapfile))
    with open(mapfile, 'r') as f:
        mapping = yaml.safe_load(f)
    template = mapping[devtype.name]['entrypoint']

    # TODO: install device certificate, using new hostname and reserved IP.
    #       exception on fail if tls_verify!=False
    try:
        device_cert_res = task.run(task=ztp_device_cert,
                                   job_id=job_id,
                                   new_hostname=task.host.name,
                                   management_ip=device_variables['mgmt_ip'])
    # TODO: handle exception from ztp_device_cert -> arista_copy_cert
    except Exception as e:
        # An exception here is only logged, the push continues
        log.exception(e)
    else:
        cert_failed = device_cert_res.failed
        if cert_failed and device_cert_required():
            log.error(
                "Unable to install device certificate for {}, aborting".
                format(device_variables['host']))
            raise Exception(device_cert_res[0].exception)
        elif cert_failed:
            log.debug(
                "Unable to install device certificate for {}".format(
                    device_variables['host']))

    rendered = task.run(task=template_file,
                        name="Generate initial device config",
                        template=template,
                        jinja_env=cnaas_jinja_env,
                        path=f"{local_repo_path}/{task.host.platform}",
                        **device_variables)

    # TODO: Handle template not found, variables not defined

    task.host["config"] = rendered.result
    # Use an extra low timeout, since connectivity is expected to be lost
    # after the management IP changes
    task.host.connection_options["napalm"].extras["timeout"] = 30

    try:
        task.run(task=napalm_configure,
                 name="Push base management config",
                 replace=True,
                 configuration=task.host["config"],
                 dry_run=False)
    except Exception:
        # If the device still answers after the failed push, the new config
        # was not committed
        task.run(task=napalm_get, getters=["facts"])
        still_reachable = not task.results[-1].failed
        if still_reachable:
            raise InitError(
                "Device {} did not commit new base management config".format(
                    task.host.name))
Ejemplo n.º 11
0
def push_sync_device(task, dry_run: bool = True, generate_only: bool = False,
                     job_id: Optional[str] = None,
                     scheduled_by: Optional[str] = None):
    """
    Nornir task to generate config and push to device

    The generated config, the template variables and a change score are
    stored on task.host under "config", "template_vars" and "change_score".

    Args:
        task: nornir task, sent by nornir when doing .run()
        dry_run: Don't commit config to device, just do compare/diff
        generate_only: Only generate text config, don't try to commit or
                       even do dry_run compare to running config
        job_id: job ID; when set, the hostname is pushed to the redis list
                'finished_devices_<job_id>' at the end
        scheduled_by: username that triggered job (not used by this task)

    Returns:
        None

    Raises:
        RepoStructureException: if mapping.yml is missing for the platform
    """
    set_thread_data(job_id)
    log = get_logger()
    hostname = task.host.name
    with sqla_session() as session:
        device_query = session.query(Device).filter(
            Device.hostname == hostname)
        dev: Device = device_query.one()
        template_vars = populate_device_vars(session, dev)
        platform = dev.platform
        devtype = dev.device_type

    with open('/etc/cnaas-nms/repository.yml', 'r') as db_file:
        repo_config = yaml.safe_load(db_file)
    local_repo_path = repo_config['templates_local']

    mapfile = os.path.join(local_repo_path, platform, 'mapping.yml')
    if not os.path.isfile(mapfile):
        raise RepoStructureException("File {} not found in template repo".format(mapfile))
    with open(mapfile, 'r') as f:
        mapping = yaml.safe_load(f)
    template = mapping[devtype.name]['entrypoint']

    log.debug("Generate config for host: {}".format(task.host.name))
    rendered = task.run(task=template_file,
                        name="Generate device config",
                        template=template,
                        jinja_env=cnaas_jinja_env,
                        path=f"{local_repo_path}/{task.host.platform}",
                        **template_vars)

    # TODO: Handle template not found, variables not defined
    # jinja2.exceptions.UndefinedError

    task.host["config"] = rendered.result
    task.host["template_vars"] = template_vars

    if not generate_only:
        log.debug("Synchronize device config for host: {} ({}:{})".format(
            task.host.name, task.host.hostname, task.host.port))

        task.host.open_connection("napalm", configuration=task.nornir.config)
        task.run(task=napalm_configure,
                 name="Sync device config",
                 replace=True,
                 configuration=task.host["config"],
                 dry_run=dry_run)
        task.host.close_connection("napalm")

        # Second entry in task.results is the napalm_configure subtask
        configure_result = task.results[1]
        if configure_result.diff:
            task.host["change_score"] = calculate_score(
                configure_result.host["config"], configure_result.diff)
        else:
            task.host["change_score"] = 0
    else:
        task.host["change_score"] = 0

    if job_id:
        finished_key = 'finished_devices_' + str(job_id)
        with redis_session() as db:
            db.lpush(finished_key, task.host.name)
Ejemplo n.º 12
0
def arista_copy_cert(task, job_id: Optional[str] = None) -> str:
    """
    Nornir task that copies the device key and certificate files to
    /mnt/flash on the device and then loads them into the certstore.

    The files are read from the 'certpath' directory in the API settings
    and are named <hostname>.key and <hostname>.crt.

    Args:
        task: nornir task, sent by nornir when doing .run()
        job_id: job ID, passed to set_thread_data()

    Returns:
        String describing the result

    Raises:
        Exception: if certpath is not configured or a file is missing
        CopyError: if a file transfer or a certstore command failed
    """
    set_thread_data(job_id)
    logger = get_logger()
    apidata = get_apidata()

    try:
        key_path = os.path.join(apidata['certpath'],
                                "{}.key".format(task.host.name))
        crt_path = os.path.join(apidata['certpath'],
                                "{}.crt".format(task.host.name))
    except KeyError:
        raise Exception("No certpath found in api.yml settings")
    except Exception as e:
        # The message has one placeholder per argument, hostname first
        raise Exception(
            "Unable to find path to cert for device {}: {}".format(
                task.host.name, e)) from e

    if not os.path.isfile(key_path):
        raise Exception("Key file {} not found".format(key_path))
    if not os.path.isfile(crt_path):
        raise Exception("Cert file {} not found".format(crt_path))

    net_connect = task.host.get_connection("netmiko", task.nornir.config)
    net_connect.fast_cli = False

    res_key = task.run(netmiko_file_transfer,
                       source_file=key_path,
                       dest_file="cnaasnms.key",
                       file_system="/mnt/flash",
                       overwrite_file=True)
    if res_key.failed:
        # Not inside an except block, so logger.exception() has no
        # traceback to attach; log the stored exception as an error
        logger.error(res_key.exception)

    res_crt = task.run(netmiko_file_transfer,
                       source_file=crt_path,
                       dest_file="cnaasnms.crt",
                       file_system="/mnt/flash",
                       overwrite_file=True)
    if res_crt.failed:
        logger.error(res_crt.exception)

    if res_key.failed or res_crt.failed:
        raise CopyError("Unable to copy cert file to device: {}".format(
            task.host.name))
    logger.debug("Certificate successfully copied to device: {}".format(
        task.host.name))

    # Load the files into the certstore, then remove them from flash
    certstore_commands = [
        "copy flash:cnaasnms.crt certificate:",
        "copy flash:cnaasnms.key sslkey:",
        "delete flash:cnaasnms.key",
        "delete flash:cnaasnms.crt",
    ]
    for cmd in certstore_commands:
        res_certstore = task.run(netmiko_send_command,
                                 command_string=cmd,
                                 enable=True)
        if res_certstore.failed:
            logger.error(
                "Unable to copy cert into certstore on device: {}, command '{}' failed"
                .format(task.host.name, cmd))
            raise CopyError(
                "Unable to copy cert into certstore on device: {}".format(
                    task.host.name))

    logger.debug(
        "Certificate successfully copied to certstore on device: {}".format(
            task.host.name))
    return "Cert copy successful"
Ejemplo n.º 13
0
def device_upgrade_task(task,
                        job_id: str,
                        filename: str,
                        url: str,
                        reboot: Optional[bool] = False,
                        download: Optional[bool] = False,
                        pre_flight: Optional[bool] = False,
                        post_flight: Optional[bool] = False,
                        post_waittime: Optional[int] = 0,
                        activate: Optional[bool] = False) -> NornirJobResult:
    """
    Nornir task that runs the selected firmware upgrade steps on one device.

    The steps run in this order, each only if its flag is set: pre-flight
    check, download, activate, reboot, post-flight check. Reboot and
    post-flight are skipped if activation reported that the firmware was
    already active.

    Args:
        task: nornir task, sent by nornir when doing .run()
        job_id: job ID, passed to every subtask; when set, the hostname is
            pushed to the redis list 'finished_devices_<job_id>' at the end
        filename: firmware file name, passed to download and activate
        url: passed to the download subtask as httpd_url
        reboot: run the reboot subtask
        download: run the firmware download subtask
        pre_flight: run the pre-flight check subtask
        post_flight: run the post-flight check subtask
        post_waittime: passed to the post-flight subtask
        activate: run the firmware activate subtask

    Returns:
        Nothing is returned explicitly.

    Raises:
        Exception: if the pre-flight check fails, or whatever the download
            or activate subtasks raised. Reboot and post-flight failures
            are logged but not raised.
    """
    set_thread_data(job_id)
    logger = get_logger()

    # If pre-flight is selected, execute the pre-flight task which
    # will verify the amount of disk space and so on.
    if pre_flight:
        logger.info('Running pre-flight check on {}'.format(task.host.name))
        try:
            res = task.run(task=arista_pre_flight_check, job_id=job_id)
        except Exception as e:
            logger.exception(
                "Exception while doing pre-flight check: {}".format(str(e)))
            raise Exception('Pre-flight check failed') from e
        else:
            if res.failed:
                # No exception is active here, so raise an explicit one; a
                # bare `raise` would fail with RuntimeError. The subtask
                # result has no failed_hosts, so name this host instead.
                logger.error('Pre-flight check failed for: {}'.format(
                    task.host.name))
                raise Exception('Pre-flight check failed')

    # If download is true, go ahead and download the firmware
    if download:
        # Download the firmware from the HTTP container.
        logger.info('Downloading firmware {} on {}'.format(
            filename, task.host.name))
        try:
            task.run(task=arista_firmware_download,
                     filename=filename,
                     httpd_url=url,
                     job_id=job_id)
        except Exception as e:
            logger.exception('Exception while downloading firmware: {}'.format(
                str(e)))
            raise

    # If activate is true, activate the firmware and verify that it is
    # present in the boot-config.
    already_active = False
    if activate:
        logger.info('Activating firmware {} on {}'.format(
            filename, task.host.name))
        try:
            task.run(task=arista_firmware_activate,
                     filename=filename,
                     job_id=job_id)
        except NornirSubTaskError as e:
            subtask_result = e.result[0]
            logger.debug(
                'Exception while activating firmware for {}: {}'.format(
                    task.host.name, subtask_result))
            if not subtask_result.exception:
                logger.error('Activate subtask result for {}: {}'.format(
                    task.host.name, subtask_result.result))
                raise
            if not isinstance(subtask_result.exception,
                              FirmwareAlreadyActiveException):
                logger.exception(
                    'Firmware activate subtask exception for {}: {}'.
                    format(task.host.name, str(subtask_result.exception)))
                raise
            already_active = True
            logger.info(
                "Firmware already active, skipping reboot and post_flight: {}"
                .format(subtask_result.exception))
        except Exception as e:
            logger.exception(
                'Exception while activating firmware for {}: {}'.format(
                    task.host.name, str(e)))
            raise

    # Reboot the device if needed, we will then lose the connection.
    if reboot and not already_active:
        logger.info('Rebooting {}'.format(task.host.name))
        try:
            task.run(task=arista_device_reboot, job_id=job_id)
        except Exception as e:
            # Best effort: errors are not fatal because the connection is
            # lost on reboot, but leave a trace instead of hiding them
            logger.debug('Ignoring exception while rebooting {}: {}'.format(
                task.host.name, str(e)))

    # If post-flight is selected, execute the post-flight task which
    # will update device facts for the selected devices
    if post_flight and not already_active:
        logger.info(
            'Running post-flight check on {}, delay start by {}s'.format(
                task.host.name, post_waittime))
        try:
            res = task.run(task=arista_post_flight_check,
                           post_waittime=post_waittime,
                           job_id=job_id)
        except Exception as e:
            logger.exception('Failed to run post-flight check: {}'.format(
                str(e)))
        else:
            if res.failed:
                logger.error('Post-flight check failed for: {}'.format(
                    task.host.name))

    if job_id:
        with redis_session() as db:
            db.lpush('finished_devices_' + str(job_id), task.host.name)
Ejemplo n.º 14
0
def arista_firmware_activate(task,
                             filename: str,
                             job_id: Optional[str] = None) -> str:
    """
    NorNir task that sets a firmware image as the boot image and verifies
    that the boot-config refers to it.

    Args:
        task: NorNir task
        filename: Name of the new firmware image
        job_id: job ID, used for thread data and for the abort check

    Returns:
        String, describing the result

    Raises:
        FirmwareAlreadyActiveException: if the boot-config already
            refers to the image
        Exception: for any other failure while activating

    """
    set_thread_data(job_id)
    log = get_logger()
    with sqla_session() as session:
        if Job.check_job_abort_status(session, job_id):
            return "Firmware activate aborted"

    try:
        boot_file_cmd = 'boot system flash:{}'.format(filename)
        # Used both before and after the change to read the boot image
        boot_check_cmd = 'show boot-config | grep -o "\\w*{}\\w*"'.format(
            filename)

        task.run(netmiko_send_command,
                 command_string='enable',
                 expect_string='.*#')

        current = task.run(netmiko_send_command,
                           command_string=boot_check_cmd)
        if current.result == filename:
            raise FirmwareAlreadyActiveException(
                'Firmware already activated in boot-config on {}'.format(
                    task.host.name))

        task.run(netmiko_send_command,
                 command_string='conf t',
                 expect_string='.*config.*#')
        task.run(netmiko_send_command, command_string=boot_file_cmd)
        task.run(netmiko_send_command,
                 command_string='end',
                 expect_string='.*#')

        verified = task.run(netmiko_send_command,
                            command_string=boot_check_cmd)

        if not isinstance(verified, MultiResult):
            raise Exception('Could not check boot-config on {}'.format(
                task.host.name))

        if verified.result != filename:
            raise Exception('Firmware not activated properly on {}'.format(
                task.host.name))
    except FirmwareAlreadyActiveException:
        # Passed through unchanged so the caller can tell it apart
        raise
    except Exception as e:
        log.exception('Failed to activate firmware on {}: {}'.format(
            task.host.name, str(e)))
        raise Exception('Failed to activate firmware')

    return "Firmware activate done."
Ejemplo n.º 15
0
def arista_firmware_download(task,
                             filename: str,
                             httpd_url: str,
                             job_id: Optional[str] = None) -> str:
    """
    NorNir task that makes the device copy a firmware image from the HTTP
    server to its flash.

    Args:
        task: NorNir task
        filename: Name of the file to download
        httpd_url: Base URL to the HTTP server
        job_id: job ID, used for thread data and for the abort check

    Returns:
        String, describing the result

    Raises:
        Exception: if the copy command failed or did not report success

    """
    set_thread_data(job_id)
    log = get_logger()
    with sqla_session() as session:
        if Job.check_job_abort_status(session, job_id):
            return "Firmware download aborted"

    url = httpd_url + '/' + filename
    # Make sure netmiko doesn't use fast_cli because it will change delay_factor
    # that is set in task.run below and cause early timeouts
    net_connect = task.host.get_connection("netmiko", task.nornir.config)
    net_connect.fast_cli = False

    try:
        with sqla_session() as session:
            dev: Device = session.query(Device).\
                filter(Device.hostname == task.host.name).one_or_none()
            device_type = dev.device_type

        # Only non-ACCESS devices copy through the MGMT vrf
        if device_type == DeviceType.ACCESS:
            copy_template = 'copy {} flash:'
        else:
            copy_template = 'copy {} vrf MGMT flash:'
        firmware_download_cmd = copy_template.format(url)

        res = task.run(netmiko_send_command,
                       command_string=firmware_download_cmd.replace("//", "/"),
                       enable=True,
                       delay_factor=30,
                       max_loops=200)

        if 'Copy completed successfully' not in res.result:
            log.debug("Firmware download failed on {} ('{}'): {}".format(
                task.host.name, firmware_download_cmd, res.result))
            # Report only the output lines that start with 'get:'
            get_lines = ', '.join(
                line for line in res.result.splitlines()
                if line.startswith('get:'))
            raise Exception(
                "Copy command did not complete successfully: {}".format(
                    get_lines))

    except NornirSubTaskError as e:
        subtask_result = e.result[0]
        log.error('{} failed to download firmware: {}'.format(
            task.host.name, subtask_result))
        log.debug('{} download subtask result: {}'.format(
            task.host.name, subtask_result.result))
        raise Exception(
            'Failed to download firmware: {}'.format(subtask_result))
    except Exception as e:
        log.error('{} failed to download firmware: {}'.format(
            task.host.name, e))
        raise Exception('Failed to download firmware: {}'.format(e))

    return "Firmware download done."
Ejemplo n.º 16
0
def _template_secrets_from_env() -> dict:
    """Collect template secrets from the process environment.

    Returns:
        Dict with every environment variable whose name starts with
        ``TEMPLATE_SECRET_``, keyed by the full variable name.
    """
    # The idea is to store secret configuration outside of the templates
    # repository, so it is picked up from the environment instead.
    return {
        name: value
        for name, value in os.environ.items()
        if name.startswith('TEMPLATE_SECRET_')
    }


def _access_template_variables(session, dev, mgmt_ip, settings,
                               base_variables: dict) -> dict:
    """Build the template variables for a device of type ACCESS.

    Args:
        session: open database session used for all lookups
        dev: Device row of the device being synchronized
        mgmt_ip: management IP of the device
        settings: settings dict for the device, must contain 'vxlans'
                  when interfaces reference VLANs
        base_variables: variables common to all device types, these
                        override access specific variables of the same name

    Returns:
        Merged dict of access variables, base variables and MLAG variables
        (later ones take precedence).

    Raises:
        Exception: if no management domain matches the management IP
    """
    mgmtdomain = cnaas_nms.db.helper.find_mgmtdomain_by_ip(
        session, dev.management_ip)
    if not mgmtdomain:
        raise Exception(
            "Could not find appropriate management domain for management_ip: {}"
            .format(dev.management_ip))

    # The prefix length of the device's own management interface is taken
    # from the gateway interface of the management domain.
    mgmt_gw_ipif = IPv4Interface(mgmtdomain.ipv4_gw)
    mgmt_prefixlen = mgmt_gw_ipif.network.prefixlen
    access_variables = {
        'mgmt_vlan_id': mgmtdomain.vlan,
        'mgmt_gw': str(mgmt_gw_ipif.ip),
        'mgmt_ipif': str(
            IPv4Interface('{}/{}'.format(mgmt_ip, mgmt_prefixlen))),
        'mgmt_prefixlen': int(mgmt_prefixlen),
        'interfaces': []
    }

    for intf in session.query(Interface).filter(
            Interface.device == dev).all():
        untagged_vlan = None
        tagged_vlan_list = []
        intfdata = None
        if intf.data:
            if 'untagged_vlan' in intf.data:
                untagged_vlan = resolve_vlanid(
                    intf.data['untagged_vlan'], settings['vxlans'])
            if 'tagged_vlan_list' in intf.data:
                tagged_vlan_list = resolve_vlanid_list(
                    intf.data['tagged_vlan_list'], settings['vxlans'])
            # Copy so the template gets a plain dict
            intfdata = dict(intf.data)
        access_variables['interfaces'].append({
            'name': intf.name,
            'ifclass': intf.configtype.name,
            'untagged_vlan': untagged_vlan,
            'tagged_vlan_list': tagged_vlan_list,
            'data': intfdata
        })

    mlag_vars = get_mlag_vars(session, dev)
    return {**access_variables, **base_variables, **mlag_vars}


def _fabric_template_variables(session, dev, hostname, mgmt_ip, infra_ip,
                               settings, base_variables: dict) -> dict:
    """Build the template variables for a device of type DIST or CORE.

    Args:
        session: open database session used for all lookups
        dev: Device row of the device being synchronized
        hostname: name of the nornir host, used to look up mgmtdomains
        mgmt_ip: management IP of the device
        infra_ip: infra IP of the device, also used to generate its ASN
        settings: settings dict for the device
        base_variables: variables common to all device types, these
                        override fabric specific variables of the same name

    Returns:
        Merged dict of fabric variables and base variables (base variables
        take precedence).
    """
    asn = generate_asn(infra_ip)
    fabric_variables = {
        'mgmt_ipif': str(IPv4Interface('{}/32'.format(mgmt_ip))),
        'mgmt_prefixlen': 32,
        'infra_ipif': str(IPv4Interface('{}/32'.format(infra_ip))),
        'infra_ip': str(infra_ip),
        'interfaces': [],
        'bgp_ipv4_peers': [],
        'bgp_evpn_peers': [],
        'mgmtdomains': [],
        'asn': asn
    }

    # Interfaces defined in settings, only downlink and custom are rendered
    ifname_peer_map = dev.get_linknet_localif_mapping(session)
    if 'interfaces' in settings and settings['interfaces']:
        for intf in settings['interfaces']:
            try:
                ifindexnum: int = Interface.interface_index_num(intf['name'])
            except ValueError:
                # No index number could be derived from the name, use 0
                ifindexnum = 0
            ifclass = intf.get('ifclass')
            if ifclass == 'downlink':
                data = {}
                if intf['name'] in ifname_peer_map:
                    data['description'] = ifname_peer_map[intf['name']]
                fabric_variables['interfaces'].append({
                    'name': intf['name'],
                    'ifclass': ifclass,
                    'indexnum': ifindexnum,
                    'data': data
                })
            elif ifclass == 'custom':
                fabric_variables['interfaces'].append({
                    'name': intf['name'],
                    'ifclass': ifclass,
                    'config': intf['config'],
                    'indexnum': ifindexnum
                })

    for mgmtdom in cnaas_nms.db.helper.get_all_mgmtdomains(session, hostname):
        fabric_variables['mgmtdomains'].append({
            'id': mgmtdom.id,
            'ipv4_gw': mgmtdom.ipv4_gw,
            'vlan': mgmtdom.vlan,
            'description': mgmtdom.description,
            'esi_mac': mgmtdom.esi_mac
        })

    # find fabric neighbors
    for neighbor_d in dev.get_neighbors(session):
        if neighbor_d.device_type not in (DeviceType.DIST, DeviceType.CORE):
            continue
        # TODO: support multiple links to the same neighbor?
        local_if = dev.get_neighbor_local_ifname(session, neighbor_d)
        local_ipif = dev.get_neighbor_local_ipif(session, neighbor_d)
        neighbor_ip = dev.get_neighbor_ip(session, neighbor_d)
        if not local_if:
            continue
        peer_info = {
            'peer_hostname': neighbor_d.hostname,
            'peer_infra_lo': str(neighbor_d.infra_ip),
            'peer_ip': str(neighbor_ip),
            'peer_asn': generate_asn(neighbor_d.infra_ip)
        }
        # The interface entry gets its own copy of the peer data so the
        # two lists never share a dict object.
        fabric_variables['interfaces'].append({
            'name': local_if,
            'ifclass': 'fabric',
            'ipv4if': local_ipif,
            **peer_info
        })
        fabric_variables['bgp_ipv4_peers'].append(peer_info)

    # populate evpn peers data, skipping the device itself
    for neighbor_d in get_evpn_spines(session, settings):
        if neighbor_d.hostname == dev.hostname:
            continue
        fabric_variables['bgp_evpn_peers'].append({
            'peer_hostname': neighbor_d.hostname,
            'peer_infra_lo': str(neighbor_d.infra_ip),
            'peer_asn': generate_asn(neighbor_d.infra_ip)
        })

    return {**fabric_variables, **base_variables}


def push_sync_device(task,
                     dry_run: bool = True,
                     generate_only: bool = False,
                     job_id: Optional[str] = None,
                     scheduled_by: Optional[str] = None):
    """
    Nornir task to generate config and push to device

    The generated config, the variables used to render it and a change
    score are stored on the nornir host as task.host["config"],
    task.host["template_vars"] and task.host["change_score"].

    Args:
        task: nornir task, sent by nornir when doing .run()
        dry_run: Don't commit config to device, just do compare/diff
        generate_only: Only generate text config, don't try to commit or
                       even do dry_run compare to running config
        job_id: job ID, passed to set_thread_data() and, if set, used to
                report the device as finished in redis
        scheduled_by: username that triggered job (not used by this task)

    Returns:
        None, results are stored on task.host as described above

    Raises:
        Exception: if the device has no management IP, or (ACCESS devices)
                   no management domain matches the management IP
        ValueError: if the platform stored for the device is not a string
        RepoStructureException: if mapping.yml is missing for the platform
    """
    set_thread_data(job_id)
    logger = get_logger()
    hostname = task.host.name

    # Everything that touches ORM objects has to happen inside the session
    with sqla_session() as session:
        dev: Device = session.query(Device).filter(
            Device.hostname == hostname).one()
        mgmt_ip = dev.management_ip
        infra_ip = dev.infra_ip
        if not mgmt_ip:
            raise Exception(
                "Could not find management IP for device {}".format(hostname))
        devtype: DeviceType = dev.device_type
        if not isinstance(dev.platform, str):
            raise ValueError("Unknown platform: {}".format(dev.platform))
        platform: str = dev.platform
        settings, settings_origin = get_settings(hostname, devtype)
        device_variables = {
            'mgmt_ip': str(mgmt_ip),
            'device_model': dev.model,
            'device_os_version': dev.os_version
        }

        # Other device types only get the common variables above
        if devtype == DeviceType.ACCESS:
            device_variables = _access_template_variables(
                session, dev, mgmt_ip, settings, device_variables)
        elif devtype == DeviceType.DIST or devtype == DeviceType.CORE:
            device_variables = _fabric_template_variables(
                session, dev, hostname, mgmt_ip, infra_ip, settings,
                device_variables)

    # Add all environment variables starting with TEMPLATE_SECRET_ to
    # the list of configuration variables.
    template_secrets = _template_secrets_from_env()

    # Merge device variables with settings before sending to template rendering
    # Device variables override any names from settings, for example the
    # interfaces list from settings are replaced with an interface list from
    # device variables that contains more information
    template_vars = {**settings, **device_variables, **template_secrets}

    with open('/etc/cnaas-nms/repository.yml', 'r') as db_file:
        repo_config = yaml.safe_load(db_file)
        local_repo_path = repo_config['templates_local']

    # mapping.yml tells which template is the entrypoint per device type
    mapfile = os.path.join(local_repo_path, platform, 'mapping.yml')
    if not os.path.isfile(mapfile):
        raise RepoStructureException(
            "File {} not found in template repo".format(mapfile))
    with open(mapfile, 'r') as f:
        mapping = yaml.safe_load(f)
        template = mapping[devtype.name]['entrypoint']

    logger.debug("Generate config for host: {}".format(task.host.name))
    r = task.run(task=text.template_file,
                 name="Generate device config",
                 template=template,
                 path=f"{local_repo_path}/{task.host.platform}",
                 **template_vars)

    # TODO: Handle template not found, variables not defined
    # jinja2.exceptions.UndefinedError

    task.host["config"] = r.result
    task.host["template_vars"] = template_vars

    if generate_only:
        task.host["change_score"] = 0
    else:
        logger.debug("Synchronize device config for host: {} ({}:{})".format(
            task.host.name, task.host.hostname, task.host.port))

        task.host.open_connection("napalm", configuration=task.nornir.config)
        try:
            sync_result = task.run(task=networking.napalm_configure,
                                   name="Sync device config",
                                   replace=True,
                                   configuration=task.host["config"],
                                   dry_run=dry_run)
        finally:
            # Always release the connection, also when the subtask fails
            task.host.close_connection("napalm")

        # Use the result returned by the subtask itself instead of relying
        # on its position in task.results
        diff = sync_result.diff
        if diff:
            task.host["change_score"] = calculate_score(
                task.host["config"], diff)
        else:
            task.host["change_score"] = 0
    if job_id:
        with redis_session() as db:
            db.lpush('finished_devices_' + str(job_id), task.host.name)