예제 #1
0
파일: test_dapi.py 프로젝트: jaydave/wazuh
def test_DistributedAPI_local_request_errors():
    """Check the behaviour when the local_request function raised an error.

    ``execute_local_request`` is patched to raise, first, ``WazuhInternalError(1001)``
    and then ``KeyError``. Each case is exercised twice:

    * without ``debug``, the error is expected to surface through
      ``raise_if_exc_routine`` as a ``ProblemException`` carrying the given code;
    * with ``debug=True``, the original exception is expected to propagate, and the
      test fails if it does not.
    """
    with patch(
            'wazuh.core.cluster.dapi.dapi.DistributedAPI.execute_local_request',
            new=AsyncMock(side_effect=WazuhInternalError(1001))):
        dapi_kwargs = {'f': agent.get_agents_summary_status, 'logger': logger}
        raise_if_exc_routine(dapi_kwargs=dapi_kwargs, expected_error=1001)

        # With debug enabled the original exception is expected to be re-raised.
        dapi_kwargs['debug'] = True
        dapi = DistributedAPI(**dapi_kwargs)
        try:
            raise_if_exc(loop.run_until_complete(dapi.distribute_function()))
        except WazuhInternalError as e:
            assert e.code == 1001
        else:
            raise AssertionError('WazuhInternalError was not raised in debug mode')

    with patch(
            'wazuh.core.cluster.dapi.dapi.DistributedAPI.execute_local_request',
            new=AsyncMock(side_effect=KeyError('Testing'))):
        # A non-Wazuh exception is expected to be reported with the generic code 1000.
        dapi_kwargs = {'f': agent.get_agents_summary_status, 'logger': logger}
        raise_if_exc_routine(dapi_kwargs=dapi_kwargs, expected_error=1000)

        dapi_kwargs['debug'] = True
        dapi = DistributedAPI(**dapi_kwargs)
        try:
            raise_if_exc(loop.run_until_complete(dapi.distribute_function()))
        except KeyError as e:
            assert 'KeyError' in repr(e)
        else:
            raise AssertionError('KeyError was not raised in debug mode')
예제 #2
0
def decode_token(token):
    """Decode a JWT formatted token and attach the processed RBAC policies.

    The token is first decoded with the local secret (algorithm and audience are
    enforced by ``jwt.decode``). It is then checked on the master node with
    ``check_token``. Finally it is compared against the current security
    configuration, so that tokens whose RBAC mode or lifetime no longer match are
    rejected.

    Parameters
    ----------
    token : str
        JWT formatted token.

    Returns
    -------
    dict
        Payload of the token. ``rbac_policies`` is set to the policies returned by
        ``check_token``, and the ``rbac_mode`` claim is moved inside them.

    Raises
    ------
    Unauthorized
        If ``jwt.decode`` raises ``JWTError``, if ``check_token`` reports the token as
        not valid, or if the RBAC mode / expiration time differ from the current
        security configuration.
    """
    try:
        # Decode the JWT with the local secret.
        payload = jwt.decode(token, generate_secret(), algorithms=[JWT_ALGORITHM],
                             audience='Wazuh API REST')

        # Validate the token and obtain its processed policies on the master node.
        token_claims = {
            'username': payload['sub'],
            'roles': payload['rbac_roles'],
            'token_nbf_time': payload['nbf'],
            'run_as': payload['run_as'],
        }
        check_request = DistributedAPI(f=check_token,
                                       f_kwargs=token_claims,
                                       request_type='local_master',
                                       is_async=False,
                                       wait_for_complete=True,
                                       logger=logging.getLogger('wazuh'))
        check_result = raise_if_exc(
            pool.submit(asyncio.run, check_request.distribute_function()).result()
        ).to_dict()['result']

        if not check_result['valid']:
            raise Unauthorized

        policies = check_result['policies']
        policies['rbac_mode'] = payload.pop('rbac_mode')
        payload['rbac_policies'] = policies

        # Reject tokens issued under a different RBAC mode or token lifetime.
        conf_request = DistributedAPI(f=get_security_conf,
                                      request_type='local_master',
                                      is_async=False,
                                      wait_for_complete=True,
                                      logger=logging.getLogger('wazuh'))
        security_conf = raise_if_exc(
            pool.submit(asyncio.run, conf_request.distribute_function()).result())

        current_rbac_mode = security_conf['rbac_mode']
        current_expiration_time = security_conf['auth_token_exp_timeout']
        rbac_mode_changed = payload['rbac_policies']['rbac_mode'] != current_rbac_mode
        if rbac_mode_changed or payload['exp'] - payload['nbf'] != current_expiration_time:
            raise Unauthorized

        return payload
    except JWTError as e:
        raise Unauthorized from e
예제 #3
0
def test_DistributedAPI_local_request(mock_local_request):
    """Test `local_request` method from class DistributedAPI and check the behaviour when an error raise.

    Successful requests are checked first (sync, with ``local_client_arg``, and
    async). ``asyncio.wait_for`` is then patched to raise a timeout, a
    ``WazuhError``, a ``WazuhInternalError`` and a ``KeyError`` in turn. Each error
    path asserts on the exception it expects and fails if no exception is raised.
    """
    dapi_kwargs = {'f': manager.status, 'logger': logger}
    raise_if_exc_routine(dapi_kwargs=dapi_kwargs)

    dapi_kwargs = {
        'f': cluster.get_nodes_info,
        'logger': logger,
        'local_client_arg': 'lc'
    }
    raise_if_exc_routine(dapi_kwargs=dapi_kwargs)

    dapi_kwargs['is_async'] = True
    raise_if_exc_routine(dapi_kwargs=dapi_kwargs)

    with patch('asyncio.wait_for',
               new=AsyncMock(side_effect=TimeoutError('Testing'))):
        dapi = DistributedAPI(f=manager.status, logger=logger)
        try:
            raise_if_exc(loop.run_until_complete(dapi.distribute_function()))
        except ProblemException as e:
            # Only one node is involved, so the first reported error is the one to check.
            first_error = next(iter(e.ext['dapi_errors'].values()))
            assert first_error['error'] == 'Timeout executing API request'
        else:
            raise AssertionError('ProblemException was not raised on timeout')

    with patch('asyncio.wait_for',
               new=AsyncMock(side_effect=WazuhError(1001))):
        dapi_kwargs = {'f': manager.status, 'logger': logger}
        raise_if_exc_routine(dapi_kwargs=dapi_kwargs, expected_error=1001)

        dapi_kwargs['debug'] = True
        raise_if_exc_routine(dapi_kwargs=dapi_kwargs, expected_error=1001)

    with patch('asyncio.wait_for',
               new=AsyncMock(side_effect=WazuhInternalError(1001))):
        dapi_kwargs = {'f': manager.status, 'logger': logger}
        raise_if_exc_routine(dapi_kwargs=dapi_kwargs, expected_error=1001)

        dapi = DistributedAPI(f=manager.status, logger=logger, debug=True)
        try:
            raise_if_exc(loop.run_until_complete(dapi.distribute_function()))
        except WazuhInternalError as e:
            assert e.code == 1001
        else:
            raise AssertionError('WazuhInternalError was not raised in debug mode')

    with patch('asyncio.wait_for',
               new=AsyncMock(side_effect=KeyError('Testing'))):
        dapi_kwargs = {'f': manager.status, 'logger': logger}
        raise_if_exc_routine(dapi_kwargs=dapi_kwargs, expected_error=1000)

        dapi = DistributedAPI(f=manager.status, logger=logger, debug=True)
        try:
            raise_if_exc(loop.run_until_complete(dapi.distribute_function()))
        except KeyError:
            # In debug mode the original exception is expected to propagate unchanged.
            pass
        else:
            raise AssertionError('KeyError was not raised in debug mode')
예제 #4
0
def test_DistributedAPI_distribute_function_mock_solver(
        api_request, request_type, node, expected):
    """Run ``distribute_function`` against a patched node description.

    ``wazuh.core.cluster.cluster.get_node`` is patched to report a node of type
    ``node`` named ``'master'``. The request is built with ``from_cluster=False``,
    and the rendered ``result`` must equal ``expected``.

    Parameters
    ----------
    api_request : callable
        Function to be executed.
    request_type : str
        Request type (local_master, distributed_master, local_any).
    node : str
        Node type (Master and Workers).
    expected : str
        Expected result.
    """
    patched_node = {'type': node, 'node': 'master'}
    with patch('wazuh.core.cluster.cluster.get_node', return_value=patched_node):
        distributed_api = DistributedAPI(f=api_request,
                                         logger=logger,
                                         request_type=request_type,
                                         from_cluster=False)
        response = loop.run_until_complete(distributed_api.distribute_function())
        rendered = raise_if_exc(response).render()
        assert rendered['result'] == expected
예제 #5
0
def test_DistributedAPI_distribute_function(api_request, request_type, node,
                                            expected, cluster_enabled):
    """Run ``distribute_function`` for a parametrized request and check its result.

    ``check_cluster_status`` is patched to return ``cluster_enabled``, and
    ``get_node`` is patched to report a node of type ``node``. The rendered
    ``result`` must equal ``expected``.

    Parameters
    ----------
    api_request : callable
        Function to be executed.
    request_type : str
        Request type (local_master, distributed_master, local_any).
    node : str
        Node type (Master and Workers).
    expected : str
        Expected result.
    cluster_enabled : bool
        Indicates whether cluster is enabled or not.
    """
    cluster_status_patch = patch('wazuh.core.cluster.dapi.dapi.check_cluster_status',
                                 return_value=cluster_enabled)
    get_node_patch = patch('wazuh.core.cluster.cluster.get_node',
                           return_value={'type': node})

    with cluster_status_patch, get_node_patch:
        distributed_api = DistributedAPI(f=api_request,
                                         logger=logger,
                                         request_type=request_type)
        response = loop.run_until_complete(distributed_api.distribute_function())
        rendered = raise_if_exc(response).render()
        assert rendered['result'] == expected
예제 #6
0
def send_command(function, command, local_master=False):
    """Send the command to the specified function through the distributed API.

    If ``local_master`` is True, the request type is ``local_master`` (used to get
    upgrade results). Otherwise it is ``distributed_master`` (used to send upgrade
    commands).

    Parameters
    ----------
    function : func
        Upgrade function.
    command : dict
        Arguments for the specified function.
    local_master : bool
        True to get the upgrade results, False to send the upgrade command.

    Returns
    -------
    Distributed API request result, passed through ``raise_if_exc``.
    """
    request_type = 'local_master' if local_master else 'distributed_master'
    dapi = DistributedAPI(f=function,
                          f_kwargs=command,
                          request_type=request_type,
                          is_async=False,
                          wait_for_complete=True,
                          logger=logger)
    # The coroutine is run via `run` (presumably `asyncio.run`) in a worker thread.
    # The executor is used as a context manager so its thread is shut down after
    # each call instead of leaking one executor per invocation.
    with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
        result = executor.submit(run, dapi.distribute_function()).result()
    return raise_if_exc(result)
예제 #7
0
def check_user(user, password, required_scopes=None):
    """Convenience method to use in the OpenAPI specification.

    The credentials are checked with ``check_user_master`` through a ``local_master``
    distributed API request.

    Parameters
    ----------
    user : str
        Unique username.
    password : str
        User password.
    required_scopes
        Not used by this function.

    Returns
    -------
    dict or None
        ``{'sub': user, 'active': True}`` when the check result is truthy,
        otherwise None.
    """
    credentials = {'user': user, 'password': password}
    dapi = DistributedAPI(f=check_user_master,
                          f_kwargs=credentials,
                          request_type='local_master',
                          is_async=False,
                          wait_for_complete=False,
                          logger=logging.getLogger('wazuh-api'))
    future = pool.submit(asyncio.run, dapi.distribute_function())
    data = raise_if_exc(future.result())

    if not data['result']:
        return None
    return {'sub': user, 'active': True}
예제 #8
0
def generate_token(user_id=None, data=None):
    """Generate an encoded JWT token. Call it once a user is properly logged on.

    The token lifetime and RBAC mode are read from the security configuration,
    which is fetched from the master node.

    Parameters
    ----------
    user_id : str
        Unique username.
    data : dict
        Roles permissions for the user. Its ``'roles'`` entry is stored in the
        ``rbac_roles`` claim.

    Returns
    -------
    JWT encoded token.
    """
    security_request = DistributedAPI(f=get_security_conf,
                                      request_type='local_master',
                                      is_async=False,
                                      wait_for_complete=True,
                                      logger=logging.getLogger('wazuh'))
    security_conf = raise_if_exc(
        pool.submit(asyncio.run, security_request.distribute_function()).result()
    ).dikt

    issued_at = int(time())
    expires_at = int(issued_at + security_conf['auth_token_exp_timeout'])

    payload = {
        "iss": JWT_ISSUER,
        "aud": "Wazuh API REST",
        "nbf": issued_at,
        "exp": expires_at,
        "sub": str(user_id),
        "rbac_roles": data['roles'],
        "rbac_mode": security_conf['rbac_mode'],
    }
    return jwt.encode(payload, generate_secret(), algorithm=JWT_ALGORITHM)
예제 #9
0
def generate_token(user_id=None, rbac_policies=None):
    """Generate an encoded JWT token. Call it once a user is properly logged on.

    The token lifetime (``auth_token_exp_timeout``) and ``rbac_mode`` are read from
    the security configuration, which is fetched from the master node.

    Parameters
    ----------
    user_id : str
        Unique username.
    rbac_policies : dict
        Permissions for the user. It is copied into the ``rbac_policies`` claim
        together with the current ``rbac_mode``; the caller's dict is not modified.

    Returns
    -------
    JWT encoded token.
    """
    dapi = DistributedAPI(f=get_security_conf,
                          request_type='local_master',
                          is_async=False,
                          wait_for_complete=True,
                          logger=logging.getLogger('wazuh'))
    result = raise_if_exc(
        pool.submit(asyncio.run, dapi.distribute_function()).result())
    # Look the settings up by key. Unpacking `values()` depended on the result
    # holding exactly these two keys in this exact order.
    token_exp = result['auth_token_exp_timeout']
    rbac_mode = result['rbac_mode']

    timestamp = int(time())
    # Build a new dict so the caller's `rbac_policies` is left untouched.
    token_policies = {**rbac_policies, 'rbac_mode': rbac_mode}
    payload = {
        "iss": JWT_ISSUER,
        "iat": timestamp,
        "exp": int(timestamp + token_exp),
        "sub": str(user_id),
        "rbac_policies": token_policies,
    }

    return jwt.encode(payload, generate_secret(), algorithm=JWT_ALGORITHM)
예제 #10
0
def decode_token(token):
    """Decode a JWT formatted token, raising ``Unauthorized`` if validation fails.

    After decoding with the local secret, the token is checked on the master node with
    ``check_token``. Its RBAC mode and lifetime (``exp - iat``) are then compared with
    the current security configuration.

    Parameters
    ----------
    token : str
        JWT formatted token.

    Returns
    -------
    dict
        Payload of the token.

    Raises
    ------
    Unauthorized
        If ``jwt.decode`` raises ``JWTError``, if ``check_token`` reports the token as
        not valid, or if the RBAC mode / expiration time differ from the current
        security configuration.
    """
    try:
        payload = jwt.decode(token, generate_secret(), algorithms=[JWT_ALGORITHM])

        token_check = DistributedAPI(f=check_token,
                                     f_kwargs={'username': payload['sub'],
                                               'token_iat_time': payload['iat']},
                                     request_type='local_master',
                                     is_async=False,
                                     wait_for_complete=True,
                                     logger=logging.getLogger('wazuh'))
        check_result = raise_if_exc(
            pool.submit(asyncio.run, token_check.distribute_function()).result())
        if not check_result.to_dict()['result']['valid']:
            raise Unauthorized

        # Reject tokens issued under a different RBAC mode or token lifetime.
        conf_request = DistributedAPI(f=get_security_conf,
                                      request_type='local_master',
                                      is_async=False,
                                      wait_for_complete=True,
                                      logger=logging.getLogger('wazuh'))
        security_conf = raise_if_exc(
            pool.submit(asyncio.run, conf_request.distribute_function()).result())

        current_rbac_mode = security_conf['rbac_mode']
        current_expiration_time = security_conf['auth_token_exp_timeout']
        rbac_mode_changed = payload['rbac_policies']['rbac_mode'] != current_rbac_mode
        if rbac_mode_changed or payload['exp'] - payload['iat'] != current_expiration_time:
            raise Unauthorized

        return payload
    except JWTError as e:
        raise Unauthorized from e
예제 #11
0
def raise_if_exc_routine(dapi_kwargs, expected_error=None):
    """Run a DistributedAPI request and check the ``ProblemException`` it produces, if any.

    Parameters
    ----------
    dapi_kwargs : dict
        Keyword arguments used to build the ``DistributedAPI`` instance.
    expected_error : int, optional
        Error code the request is expected to fail with. When None, the request is
        expected to complete without raising ``ProblemException``.

    Raises
    ------
    AssertionError
        If a ``ProblemException`` is raised when no error was expected, if it carries a
        different code than ``expected_error``, or if ``expected_error`` is set but no
        ``ProblemException`` is raised.
    """
    dapi = DistributedAPI(**dapi_kwargs)
    try:
        raise_if_exc(loop.run_until_complete(dapi.distribute_function()))
    except ProblemException as e:
        if expected_error is None:
            raise AssertionError(f'Unexpected exception: {e.ext}') from e
        assert e.ext['code'] == expected_error, \
            f"Expected error {expected_error}, got {e.ext['code']}"
    else:
        # Previously a missing error went unnoticed, letting broken error paths pass.
        if expected_error is not None:
            raise AssertionError(f'Expected error {expected_error} was not raised')