def test_DistributedAPI_local_request_errors():
    """Check the behaviour when the local_request function raised an error.

    ``execute_local_request`` is patched to raise first a ``WazuhInternalError``
    and then a plain ``KeyError``.

    Without ``debug``, each error is checked through ``raise_if_exc_routine``
    against the expected error code. With ``debug=True``, the original
    exception is expected to reach the caller. The test fails explicitly if it
    does not.
    """
    with patch('wazuh.core.cluster.dapi.dapi.DistributedAPI.execute_local_request',
               new=AsyncMock(side_effect=WazuhInternalError(1001))):
        dapi_kwargs = {'f': agent.get_agents_summary_status, 'logger': logger}
        raise_if_exc_routine(dapi_kwargs=dapi_kwargs, expected_error=1001)

        # Same request with debug enabled: the internal error itself must be raised.
        dapi_kwargs['debug'] = True
        dapi = DistributedAPI(**dapi_kwargs)
        try:
            raise_if_exc(loop.run_until_complete(dapi.distribute_function()))
        except WazuhInternalError as e:
            assert e.code == 1001
        else:
            raise AssertionError('WazuhInternalError was not raised')

    with patch('wazuh.core.cluster.dapi.dapi.DistributedAPI.execute_local_request',
               new=AsyncMock(side_effect=KeyError('Testing'))):
        dapi_kwargs = {'f': agent.get_agents_summary_status, 'logger': logger}
        # A non-Wazuh exception is expected to be reported with error code 1000.
        raise_if_exc_routine(dapi_kwargs=dapi_kwargs, expected_error=1000)

        # Specify KeyError: with debug enabled the original KeyError must be raised.
        dapi = DistributedAPI(f=agent.get_agents_summary_status, logger=logger, debug=True)
        try:
            raise_if_exc(loop.run_until_complete(dapi.distribute_function()))
        except KeyError as e:
            assert 'KeyError' in repr(e)
        else:
            raise AssertionError('KeyError was not raised')
def decode_token(token):
    """Decode a jwt formatted token and add processed policies.

    The token is validated against the master node, which also supplies the
    processed RBAC policies. It is then checked against the current security
    configuration (RBAC mode and token lifetime).

    Parameters
    ----------
    token : str
        JWT formatted token.

    Raises
    ------
    Unauthorized
        If the token cannot be decoded, the master reports it as invalid, or the
        RBAC mode / expiration timeout have changed since it was issued.

    Returns
    -------
    dict
        Payload of the token, with ``rbac_policies`` added and ``rbac_mode``
        moved inside ``rbac_policies``.
    """
    try:
        # Decode JWT token with local secret
        payload = jwt.decode(token, generate_secret(), algorithms=[JWT_ALGORITHM],
                             audience='Wazuh API REST')

        # Check token and add processed policies in the Master node
        token_check = DistributedAPI(f=check_token,
                                     f_kwargs={
                                         'username': payload['sub'],
                                         'roles': payload['rbac_roles'],
                                         'token_nbf_time': payload['nbf'],
                                         'run_as': payload['run_as'],
                                     },
                                     request_type='local_master',
                                     is_async=False,
                                     wait_for_complete=True,
                                     logger=logging.getLogger('wazuh'))
        check_future = pool.submit(asyncio.run, token_check.distribute_function())
        check_result = raise_if_exc(check_future.result()).to_dict()['result']
        if not check_result['valid']:
            raise Unauthorized

        policies = check_result['policies']
        payload['rbac_policies'] = policies
        policies['rbac_mode'] = payload.pop('rbac_mode')

        # Detect local changes
        conf_request = DistributedAPI(f=get_security_conf,
                                      request_type='local_master',
                                      is_async=False,
                                      wait_for_complete=True,
                                      logger=logging.getLogger('wazuh'))
        conf_future = pool.submit(asyncio.run, conf_request.distribute_function())
        security_conf = raise_if_exc(conf_future.result())

        current_rbac_mode = security_conf['rbac_mode']
        current_expiration_time = security_conf['auth_token_exp_timeout']
        if (payload['rbac_policies']['rbac_mode'] != current_rbac_mode
                or (payload['exp'] - payload['nbf']) != current_expiration_time):
            raise Unauthorized

        return payload
    except JWTError as e:
        raise Unauthorized from e
def test_DistributedAPI_local_request(mock_local_request):
    """Test `local_request` method from class DistributedAPI and check the behaviour when an error raise.

    ``asyncio.wait_for`` is patched to raise several kinds of errors. For each
    one, the test checks how the error reaches the caller, both with and
    without ``debug``. Every ``try`` block fails explicitly if the expected
    exception is not raised.

    Parameters
    ----------
    mock_local_request
        Fixture requested by the test; not referenced in the body.
    """
    dapi_kwargs = {'f': manager.status, 'logger': logger}
    raise_if_exc_routine(dapi_kwargs=dapi_kwargs)

    dapi_kwargs = {'f': cluster.get_nodes_info, 'logger': logger, 'local_client_arg': 'lc'}
    raise_if_exc_routine(dapi_kwargs=dapi_kwargs)
    dapi_kwargs['is_async'] = True
    raise_if_exc_routine(dapi_kwargs=dapi_kwargs)

    # Timeout: the error message is expected in the first entry of dapi_errors.
    with patch('asyncio.wait_for', new=AsyncMock(side_effect=TimeoutError('Testing'))):
        dapi = DistributedAPI(f=manager.status, logger=logger)
        try:
            raise_if_exc(loop.run_until_complete(dapi.distribute_function()))
        except ProblemException as e:
            first_error = next(iter(e.ext['dapi_errors'].values()))
            assert first_error['error'] == 'Timeout executing API request'
        else:
            raise AssertionError('ProblemException was not raised')

    # WazuhError: reported with its code, with and without debug.
    with patch('asyncio.wait_for', new=AsyncMock(side_effect=WazuhError(1001))):
        dapi_kwargs = {'f': manager.status, 'logger': logger}
        raise_if_exc_routine(dapi_kwargs=dapi_kwargs, expected_error=1001)
        dapi_kwargs['debug'] = True
        raise_if_exc_routine(dapi_kwargs=dapi_kwargs, expected_error=1001)

    # WazuhInternalError: reported with its code; raised as-is in debug mode.
    with patch('asyncio.wait_for', new=AsyncMock(side_effect=WazuhInternalError(1001))):
        dapi_kwargs = {'f': manager.status, 'logger': logger}
        raise_if_exc_routine(dapi_kwargs=dapi_kwargs, expected_error=1001)
        dapi = DistributedAPI(f=manager.status, logger=logger, debug=True)
        try:
            raise_if_exc(loop.run_until_complete(dapi.distribute_function()))
        except WazuhInternalError as e:
            assert e.code == 1001
        else:
            raise AssertionError('WazuhInternalError was not raised')

    # Generic exception: reported as error 1000; raised as-is in debug mode.
    with patch('asyncio.wait_for', new=AsyncMock(side_effect=KeyError('Testing'))):
        dapi_kwargs = {'f': manager.status, 'logger': logger}
        raise_if_exc_routine(dapi_kwargs=dapi_kwargs, expected_error=1000)
        dapi = DistributedAPI(f=manager.status, logger=logger, debug=True)
        try:
            raise_if_exc(loop.run_until_complete(dapi.distribute_function()))
        except KeyError:
            pass
        else:
            raise AssertionError('KeyError was not raised')
def test_DistributedAPI_distribute_function_mock_solver(
        api_request, request_type, node, expected):
    """Run distribute_function against a mocked node that is not in the cluster.

    ``get_node`` is patched to report a node named ``master`` of the given type.
    The request is marked as not coming from the cluster.

    Parameters
    ----------
    api_request : callable
        Function executed by the DistributedAPI.
    request_type : str
        One of local_master, distributed_master or local_any.
    node : str
        Node type reported by the patched ``get_node`` (master or worker).
    expected : str
        Value expected under ``result`` in the rendered response.
    """
    mocked_node = {'type': node, 'node': 'master'}
    with patch('wazuh.core.cluster.cluster.get_node', return_value=mocked_node):
        dapi = DistributedAPI(f=api_request,
                              logger=logger,
                              request_type=request_type,
                              from_cluster=False)
        response = loop.run_until_complete(dapi.distribute_function())
        data = raise_if_exc(response)
        assert data.render()['result'] == expected
def test_DistributedAPI_distribute_function(api_request, request_type, node, expected, cluster_enabled):
    """Run distribute_function for a parametrized request type and node type.

    Parameters
    ----------
    api_request : callable
        Function executed by the DistributedAPI.
    request_type : str
        One of local_master, distributed_master or local_any.
    node : str
        Node type reported by the patched ``get_node`` (master or worker).
    expected : str
        Value expected under ``result`` in the rendered response.
    cluster_enabled : bool
        Value returned by the patched ``check_cluster_status``.
    """
    # Patch the cluster status check and the node information together.
    with patch('wazuh.core.cluster.dapi.dapi.check_cluster_status', return_value=cluster_enabled), \
            patch('wazuh.core.cluster.cluster.get_node', return_value={'type': node}):
        dapi = DistributedAPI(f=api_request, logger=logger, request_type=request_type)
        response = loop.run_until_complete(dapi.distribute_function())
        data = raise_if_exc(response)
        assert data.render()['result'] == expected
def send_command(function, command, local_master=False):
    """Send the command to the specified function.

    If ``local_master`` is True the request type is ``local_master`` (used to get
    upgrade results); otherwise it is ``distributed_master``.

    Parameters
    ----------
    function : func
        Upgrade function.
    command : dict
        Arguments for the specified function.
    local_master : bool
        True to get the upgrade results, False to send an upgrade command.

    Returns
    -------
    Distributed API request result.
    """
    dapi = DistributedAPI(f=function,
                          f_kwargs=command,
                          request_type='local_master' if local_master else 'distributed_master',
                          is_async=False,
                          wait_for_complete=True,
                          logger=logger)
    # The coroutine is driven by ``run`` in a worker thread. The executor is used
    # as a context manager so its thread is shut down deterministically instead
    # of relying on garbage collection of a per-call pool.
    with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
        response = pool.submit(run, dapi.distribute_function()).result()
    return raise_if_exc(response)
def check_user(user, password, required_scopes=None):
    """Convenience method to use in OpenAPI specification.

    The credentials are checked by running ``check_user_master`` as a
    ``local_master`` request.

    Parameters
    ----------
    user : str
        Unique username.
    password : str
        User password.
    required_scopes : optional
        Accepted but not used by this function.

    Returns
    -------
    dict or None
        ``{'sub': user, 'active': True}`` when the request result is truthy,
        otherwise None.
    """
    dapi = DistributedAPI(f=check_user_master,
                          f_kwargs={'user': user, 'password': password},
                          request_type='local_master',
                          is_async=False,
                          wait_for_complete=False,
                          logger=logging.getLogger('wazuh-api'))
    future = pool.submit(asyncio.run, dapi.distribute_function())
    data = raise_if_exc(future.result())

    if not data['result']:
        return None
    return {'sub': user, 'active': True}
def generate_token(user_id=None, data=None):
    """Generate an encoded jwt token.

    This method should be called once a user is properly logged on. The token
    lifetime and RBAC mode are read from the security configuration obtained
    through a ``local_master`` request.

    Parameters
    ----------
    user_id : str
        Unique username, stored in the ``sub`` claim.
    data : dict
        Must contain a ``roles`` entry, stored in the ``rbac_roles`` claim.

    Returns
    -------
    JWT encoded token.
    """
    dapi = DistributedAPI(f=get_security_conf,
                          request_type='local_master',
                          is_async=False,
                          wait_for_complete=True,
                          logger=logging.getLogger('wazuh'))
    future = pool.submit(asyncio.run, dapi.distribute_function())
    security_conf = raise_if_exc(future.result()).dikt

    now = int(time())
    # Key order is kept stable so the encoded token layout does not change.
    payload = {
        "iss": JWT_ISSUER,
        "aud": "Wazuh API REST",
        "nbf": now,
        "exp": int(now + security_conf['auth_token_exp_timeout']),
        "sub": str(user_id),
        "rbac_roles": data['roles'],
        "rbac_mode": security_conf['rbac_mode'],
    }

    return jwt.encode(payload, generate_secret(), algorithm=JWT_ALGORITHM)
def generate_token(user_id=None, rbac_policies=None):
    """Generate an encoded jwt token.

    This method should be called once a user is properly logged on. The token
    lifetime and RBAC mode are read from the security configuration obtained
    through a ``local_master`` request.

    Parameters
    ----------
    user_id : str
        Unique username, stored in the ``sub`` claim.
    rbac_policies : dict
        Permissions for the user. A copy extended with ``rbac_mode`` is stored
        in the ``rbac_policies`` claim; the caller's dict is not modified.

    Returns
    -------
    JWT encoded token.
    """
    dapi = DistributedAPI(f=get_security_conf,
                          request_type='local_master',
                          is_async=False,
                          wait_for_complete=True,
                          logger=logging.getLogger('wazuh'))
    security_conf = raise_if_exc(
        pool.submit(asyncio.run, dapi.distribute_function()).result())
    # Look the values up by key. Unpacking ``.values()`` depended on the
    # insertion order of the configuration and broke if a key was added.
    token_exp = security_conf['auth_token_exp_timeout']
    rbac_mode = security_conf['rbac_mode']

    timestamp = int(time())
    # Copy so the caller's policies dict is not mutated as a side effect.
    policies = dict(rbac_policies)
    policies['rbac_mode'] = rbac_mode
    payload = {
        "iss": JWT_ISSUER,
        "iat": timestamp,
        "exp": int(timestamp + token_exp),
        "sub": str(user_id),
        "rbac_policies": policies,
    }

    return jwt.encode(payload, generate_secret(), algorithm=JWT_ALGORITHM)
def decode_token(token):
    """Decode a jwt formatted token.

    After decoding, the token is checked by the master node (``check_token``)
    and compared with the current security configuration.

    Parameters
    ----------
    token : str
        JWT formatted token.

    Raises
    ------
    Unauthorized
        If the token cannot be decoded, the master reports it as invalid, or the
        RBAC mode / expiration timeout differ from the current configuration.

    Returns
    -------
    dict
        Payload of the token.
    """
    try:
        payload = jwt.decode(token, generate_secret(), algorithms=[JWT_ALGORITHM])

        token_check = DistributedAPI(f=check_token,
                                     f_kwargs={
                                         'username': payload['sub'],
                                         'token_iat_time': payload['iat'],
                                     },
                                     request_type='local_master',
                                     is_async=False,
                                     wait_for_complete=True,
                                     logger=logging.getLogger('wazuh'))
        check_future = pool.submit(asyncio.run, token_check.distribute_function())
        check_response = raise_if_exc(check_future.result())
        if not check_response.to_dict()['result']['valid']:
            raise Unauthorized

        conf_request = DistributedAPI(f=get_security_conf,
                                      request_type='local_master',
                                      is_async=False,
                                      wait_for_complete=True,
                                      logger=logging.getLogger('wazuh'))
        conf_future = pool.submit(asyncio.run, conf_request.distribute_function())
        security_conf = raise_if_exc(conf_future.result())

        current_rbac_mode = security_conf['rbac_mode']
        current_expiration_time = security_conf['auth_token_exp_timeout']
        # Reject tokens issued under a different RBAC mode or token lifetime.
        if (payload['rbac_policies']['rbac_mode'] != current_rbac_mode
                or (payload['exp'] - payload['iat']) != current_expiration_time):
            raise Unauthorized

        return payload
    except JWTError as e:
        raise Unauthorized from e
def raise_if_exc_routine(dapi_kwargs, expected_error=None):
    """Run a DistributedAPI request and check how it finishes.

    Parameters
    ----------
    dapi_kwargs : dict
        Keyword arguments used to build the ``DistributedAPI`` instance.
    expected_error : int, optional
        Error code the request is expected to fail with, read from the
        ``ProblemException.ext['code']`` field. When not given, the request must
        finish without raising ``ProblemException``.

    Raises
    ------
    AssertionError
        If an unexpected ``ProblemException`` is raised, the error code does not
        match, or an expected error is not raised at all.
    """
    dapi = DistributedAPI(**dapi_kwargs)
    try:
        raise_if_exc(loop.run_until_complete(dapi.distribute_function()))
    except ProblemException as e:
        if expected_error:
            assert e.ext['code'] == expected_error
        else:
            raise AssertionError(f'Unexpected exception: {e.ext}')
    else:
        # Without this branch a missing error went unnoticed and the test passed.
        if expected_error:
            raise AssertionError(f'Expected error {expected_error} was not raised')