def test_retry_condition(self):
    auth = ChannelAuth(self.url, self.username, self.password)
    with patch("requests.Session"):
        channel = Channel(self.url,
                          auth=auth,
                          consumer_group=self.consumer_group)
        self.assertFalse(
            channel._retry_if_not_consumer_error(ConsumerError()))
        self.assertTrue(channel._retry_if_not_consumer_error(Exception()))
def test_stop(self):
    auth = ChannelAuth(self.url, self.username, self.password)
    with patch("requests.Session") as session:
        session.return_value = MagicMock()  # self._session
        session.return_value.request = MagicMock()

        def on_request(method, url, json=None):  # pylint: disable=redefined-outer-name
            del method, json
            response_json = {}
            if url.endswith('/consumers'):
                response_json = {"consumerInstanceId": 1234}
            elif url.endswith('/records'):
                response_json = {"records": []}
            response_mock = MagicMock()
            response_mock.status_code = 200
            response_mock.json = MagicMock(return_value=response_json)
            return response_mock
        session.return_value.request.side_effect = on_request

        channel = Channel(self.url,
                          auth=auth,
                          consumer_group=self.consumer_group,
                          retry_on_fail=False)

        def on_consume(_):
            return True

        run_stopped = [False]

        def run_worker():
            channel.run(on_consume,
                        wait_between_queries=30,
                        topics=["topic1", "topic2", "topic3"])
            run_stopped[0] = True
        thread = threading.Thread(target=run_worker)
        thread.daemon = True
        thread.start()

        # Wait for the channel create, subscribe, and first consume
        # (records) call to be made
        while len(session.return_value.request.mock_calls) < 3:
            time.sleep(0.1)

        self.assertFalse(run_stopped[0])
        channel.stop()
        thread.join()
        self.assertTrue(run_stopped[0])

        session.return_value.request.assert_any_call(
            "post",
            "http://localhost/databus/consumer-service/v1/consumers/1234/subscription",
            json={"topics": ["topic1", "topic2", "topic3"]})
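# A minimal application-side sketch (not part of the test above) of the
# behavior it exercises: Channel.stop() unblocks a Channel.run() loop that is
# running on another thread. Assumes the dxlstreamingclient package; the URL,
# credentials, consumer group, and topic below are placeholders.
import threading

from dxlstreamingclient.channel import Channel, ChannelAuth

channel = Channel("http://localhost",
                  auth=ChannelAuth("http://localhost", "user", "password"),
                  consumer_group="my-consumer-group")

consumer_thread = threading.Thread(
    target=lambda: channel.run(lambda payloads: True,
                               wait_between_queries=30,
                               topics=["topic1"]))
consumer_thread.daemon = True
consumer_thread.start()

# ... later, from the main thread (for example in a shutdown handler) ...
channel.stop()           # unblocks the run() loop
consumer_thread.join()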
def activity_feed(self):
    logging.info("Starting event loop...")
    try:
        with Channel(self.url,
                     auth=self.auth,
                     consumer_group='mvisionedr_events',
                     verify_cert_bundle='') as channel:

            def process_callback(payloads):
                if payloads:
                    for payload in payloads:
                        if self.enrich == 'True':
                            payload = self.epo_enrich(payload)
                        print('Event received: {0}'.format(
                            json.dumps(payload)))
                        if args.module:
                            self.run_modules(payload)
                return True

            channel.run(process_callback,
                        wait_between_queries=5,
                        topics=TOPICS)
    except Exception as e:
        exc_type, exc_obj, exc_tb = sys.exc_info()
        print(
            "ERROR: Error in {location}.{funct_name}() - line {line_no} : {error}"
            .format(location=__name__,
                    funct_name=sys._getframe().f_code.co_name,
                    line_no=exc_tb.tb_lineno,
                    error=str(e)))
def main():
    parser = setup_argument_parser()
    args = parser.parse_args()
    period = args.period
    loglevel = args.loglevel

    logging.basicConfig(level=getattr(logging, loglevel.upper(), None),
                        stream=args.logfile)
    """
    # Add SysLog support at global level
    # (TODO: Ask MDC, as this "global" solution seems to be simpler than a
    # custom-level sample, which requires more code - as of today):
    logger = logging.getLogger()
    handler = SysLogHandler(address='/dev/log')
    handler.setLevel(loglevel.upper())
    logger.addHandler(handler)
    # Default log location in case needed at sudo level for investigating
    # payloads (normally commented out):
    # logger.addHandler(logging.FileHandler("/var/log/mvedr_activity_feed.log"))
    """

    sys.path.append(os.getcwd())

    # load modules containing subscriptions
    for module in args.module:
        try:
            __import__(module)
        except Exception as exp:
            logging.critical("While attempting to load module '%s': %s",
                             module, exp)
            exit(1)

    configs = get_config(args)

    logging.info("Starting event loop...")
    try:
        with Channel(args.url,
                     auth=ChannelAuth(args.url,
                                      args.username,
                                      args.password,
                                      verify_cert_bundle=args.cert_bundle),
                     consumer_group=args.consumer_group,
                     verify_cert_bundle=args.cert_bundle) as channel:

            def process_callback(payloads):
                print("Received payloads: \n%s" %
                      json.dumps(payloads, indent=4, sort_keys=True))
                invoke(payloads, configs)
                return True

            # Consume records indefinitely
            channel.run(process_callback,
                        wait_between_queries=period,
                        topics=args.topic)
    except Exception as e:
        logging.error("Unexpected error: {}".format(e))
def main():
    parser = setup_argument_parser()
    args = parser.parse_args()
    period = args.period
    loglevel = args.loglevel

    logging.basicConfig(level=getattr(logging, loglevel.upper(), None),
                        stream=args.logfile)

    logger = logging.getLogger()
    ch = logging.StreamHandler()
    formatter = logging.Formatter("%(asctime)s;%(levelname)s;%(message)s")
    ch.setFormatter(formatter)
    logger.addHandler(ch)

    sys.path.append(os.getcwd())

    # load modules containing subscriptions
    for module in args.module:
        try:
            __import__(module)
        except Exception as exp:
            logging.critical("While attempting to load module '%s': %s",
                             module, exp)
            exit(1)

    configs = get_config(args)

    logging.info("Starting event loop...")
    try:
        with Channel(args.url,
                     auth=ChannelAuth(args.url,
                                      args.username,
                                      args.password,
                                      verify_cert_bundle=args.cert_bundle),
                     consumer_group=args.consumer_group,
                     verify_cert_bundle=args.cert_bundle,
                     offset='latest' if not args.consumer_reset else 'earliest') as channel:

            def process_callback(payloads):
                logging.debug("Received payloads: \n%s",
                              json.dumps(payloads, indent=4, sort_keys=True))
                invoke(payloads, configs)
                return True

            # Consume records indefinitely
            channel.run(process_callback,
                        wait_between_queries=period,
                        topics=args.topic)
    except Exception as e:
        logging.error("Unexpected error: {}".format(e))
def test_service(self):
    with fake_streaming_service.ConsumerService(0) as service:
        channel_url = furl(BASE_CHANNEL_URL).set(port=service.port)
        with Channel(channel_url,
                     auth=ChannelAuth(
                         channel_url,
                         fake_streaming_service.AUTH_USER,
                         fake_streaming_service.AUTH_PASSWORD),
                     consumer_group=fake_streaming_service.CONSUMER_GROUP) \
                as channel:
            channel.create()
            self.assertEqual(len(service._active_consumers), 1)

            topic = "case-mgmt-events"
            channel.subscribe(topic)

            expected_records = []
            for record in fake_streaming_service.DEFAULT_RECORDS:
                if record['routingData']['topic'] == topic:
                    expected_records.append(
                        json.loads(
                            base64.b64decode(
                                record['message']['payload']).decode()))
            records_consumed = channel.consume()
            self.assertEqual(expected_records, records_consumed)

            channel.commit()
            self.assertEqual([], channel.consume())

            message_payload = {"detail": "Hello from OpenDXL"}
            produce_payload = {
                "records": [{
                    "routingData": {
                        "topic": topic,
                        "shardingKey": ""
                    },
                    "message": {
                        "headers": {},
                        "payload": base64.b64encode(
                            json.dumps(message_payload).encode()).decode()
                    }
                }]
            }
            channel.produce(produce_payload)

            records_consumed = channel.consume()
            expected_records = [message_payload]
            self.assertEqual(expected_records, records_consumed)
        self.assertEqual(len(service._active_consumers), 0)
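# A minimal sketch (not part of the test above) of the same consumer lifecycle
# without the context manager; the URL, credentials, consumer group, and topic
# are placeholders. In the test, exiting the 'with' block tears the channel
# down (hence the final assertion that no active consumers remain); without the
# context manager, the consumer instance can be removed explicitly via delete().
from dxlstreamingclient.channel import Channel, ChannelAuth

channel = Channel("http://localhost",
                  auth=ChannelAuth("http://localhost", "user", "password"),
                  consumer_group="my-consumer-group")
try:
    channel.create()                      # register a consumer instance
    channel.subscribe("case-mgmt-events")
    records = channel.consume()           # decoded payload dictionaries
    channel.commit()                      # commit the consumed offsets
finally:
    channel.delete()                      # remove the consumer instance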
CHANNEL_TOPIC_SUBSCRIPTIONS = ["case-mgmt-events",
                               "my-topic",
                               "topic-abc123"]

# Path to a CA bundle file containing certificates of trusted CAs. The CA
# bundle is used to validate that the certificate of the server being connected
# to was signed by a valid authority. If set to an empty string, the server
# certificate is not validated.
VERIFY_CERTIFICATE_BUNDLE = ""

# This constant controls the frequency (in seconds) at which the channel 'run'
# call below polls the streaming service for new records.
WAIT_BETWEEN_QUERIES = 5

# Create a new channel object
with Channel(CHANNEL_URL,
             auth=ChannelAuth(CHANNEL_URL,
                              CHANNEL_USERNAME,
                              CHANNEL_PASSWORD,
                              verify_cert_bundle=VERIFY_CERTIFICATE_BUNDLE),
             consumer_group=CHANNEL_CONSUMER_GROUP,
             verify_cert_bundle=VERIFY_CERTIFICATE_BUNDLE) as channel:

    # Create a function which will be called back upon by the 'run' method (see
    # below) when records are received from the channel.
    def process_callback(payloads):
        # Print the payloads which were received. 'payloads' is a list of
        # dictionary objects extracted from the records received from the
        # channel.
        logger.info("Received payloads: \n%s",
                    json.dumps(payloads, indent=4, sort_keys=True))

        # Return 'True' in order for the 'run' call to continue attempting to
        # consume records.
        return True
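    # Consume records indefinitely. The comments above refer to this 'run'
    # call; the invocation below follows the same pattern as the other consume
    # examples in this listing (the exact arguments shown are an assumption).
    channel.run(process_callback,
                wait_between_queries=WAIT_BETWEEN_QUERIES,
                topics=CHANNEL_TOPIC_SUBSCRIPTIONS)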
def main():
    parser = setup_argument_parser()
    args = parser.parse_args()

    if args.username and args.client_id:
        logging.critical(
            "Use only one of the authentication credentials, either username/password or client_id/client_secret"
        )
        exit(1)

    if not args.username:
        if not args.client_id:
            logging.critical(
                "Missing the authentication credentials. Use either username/password or client_id/client_secret"
            )
            exit(1)
        if not args.client_secret:
            args.client_secret = getpass.getpass(
                prompt='MVISION EDR Client Secret: ')

    if not args.client_secret and not args.password:
        args.password = getpass.getpass(prompt='MVISION EDR Password: ')

    period = args.period
    loglevel = args.loglevel

    logging.basicConfig(level=getattr(logging, loglevel.upper(), None),
                        stream=args.logfile)

    sys.path.append(os.getcwd())

    # load modules containing subscriptions
    for module in args.module:
        try:
            __import__(module)
        except Exception as exp:
            logging.critical("While attempting to load module '%s': %s",
                             module, exp)
            exit(1)

    configs = get_config(args)

    CHANNEL_SCOPE = "soc.hts.c soc.hts.r soc.rts.c soc.rts.r soc.qry.pr soc.skr.pr soc.evt.vi soc.cop.r dxls.evt.w dxls.evt.r"
    CHANNEL_GRANT_TYPE = "client_credentials"
    CHANNEL_AUDIENCE = "mcafee"
    CHANNEL_IAM_URL = 'https://iam.mcafee-cloud.com/'
    if args.preprod:
        CHANNEL_IAM_URL = 'https://preprod.iam.mcafee-cloud.com/'

    logging.info("Consumer Group={a}".format(a=args.consumer_group))
    logging.info("Topics={a}".format(a=args.topic))
    logging.info("Starting event loop...")
    try:
        if args.username:
            with Channel(args.url,
                         auth=ChannelAuth(args.url,
                                          args.username,
                                          args.password,
                                          verify_cert_bundle=args.cert_bundle),
                         consumer_group=args.consumer_group,
                         verify_cert_bundle=args.cert_bundle,
                         offset='latest' if not args.consumer_reset else 'earliest') as channel:

                def process_callback(payloads):
                    logging.debug(
                        "Received payloads: \n%s",
                        json.dumps(payloads, indent=4, sort_keys=True))
                    invoke(payloads, configs)
                    return True

                # Consume records indefinitely
                channel.run(process_callback,
                            wait_between_queries=period,
                            topics=args.topic)

        if args.client_id:
            with Channel(args.url,
                         auth=ClientCredentialsChannelAuth(
                             CHANNEL_IAM_URL,
                             args.client_id,
                             args.client_secret,
                             verify_cert_bundle=args.cert_bundle,
                             scope=CHANNEL_SCOPE,
                             grant_type=CHANNEL_GRANT_TYPE,
                             audience=CHANNEL_AUDIENCE),
                         consumer_group=args.consumer_group,
                         verify_cert_bundle=args.cert_bundle,
                         offset='latest' if not args.consumer_reset else 'earliest') as channel:

                def process_callback(payloads):
                    logging.debug(
                        "Received payloads: \n%s",
                        json.dumps(payloads, indent=4, sort_keys=True))
                    invoke(payloads, configs)
                    return True

                # Consume records indefinitely
                channel.run(process_callback,
                            wait_between_queries=period,
                            topics=args.topic)
    except Exception as e:
        logging.error("Unexpected error: {}".format(e))
def main():
    parser = setup_argument_parser()
    args = parser.parse_args()

    if args.username and args.client_id:
        logging.critical(
            "Use only one of the authentication credentials, either username/password or client_id/client_secret"
        )
        exit(1)

    if not args.username:
        if not args.client_id:
            logging.critical(
                "Missing the authentication credentials. Use either username/password or client_id/client_secret"
            )
            exit(1)
        if not args.client_secret:
            args.client_secret = getpass.getpass(
                prompt='MVISION EDR Client Secret: ')

    if not args.client_secret and not args.password:
        args.password = getpass.getpass(prompt='MVISION EDR Password: ')

    CHANNEL_SCOPE = "soc.hts.c soc.hts.r soc.rts.c soc.rts.r soc.qry.pr soc.skr.pr soc.evt.vi soc.cop.r dxls.evt.w dxls.evt.r"
    CHANNEL_GRANT_TYPE = "client_credentials"
    CHANNEL_AUDIENCE = "mcafee"
    CHANNEL_IAM_URL = 'https://iam.mcafee-cloud.com/'
    if args.preprod:
        CHANNEL_IAM_URL = 'https://preprod.iam.mcafee-cloud.com/'

    root_dir = os.path.dirname(os.path.abspath(__file__))
    sys.path.append(root_dir + "/../..")
    sys.path.append(root_dir + "/..")

    # Configure local logger
    logging.getLogger().setLevel(logging.INFO)
    logger = logging.getLogger(__name__)

    # Create the message payload to be included in a record
    message_payload = {"message": "Hello from Activity Feed"}

    # Create the full payload with records to produce to the channel
    channel_payload = {
        "records": [{
            "routingData": {
                "topic": args.topic + "-" + args.tenant_id,
                "shardingKey": ""
            },
            "message": {
                "headers": {},
                # Convert the message payload from a dictionary to a
                # base64-encoded string.
                "payload": base64.b64encode(
                    json.dumps(message_payload).encode()).decode()
            }
        }]
    }

    # Create a new channel object with username/password
    if args.username:
        with Channel(args.url,
                     auth=ChannelAuth(args.url,
                                      args.username,
                                      args.password,
                                      verify_cert_bundle=args.cert_bundle),
                     consumer_group=args.consumer_group,
                     verify_cert_bundle=args.cert_bundle) as channel:
            # Produce the payload records to the channel
            channel.produce(channel_payload)

    # Create a new channel object with client_id/client_secret
    if args.client_id:
        with Channel(args.url,
                     auth=ClientCredentialsChannelAuth(
                         CHANNEL_IAM_URL,
                         args.client_id,
                         args.client_secret,
                         verify_cert_bundle=args.cert_bundle,
                         scope=CHANNEL_SCOPE,
                         grant_type=CHANNEL_GRANT_TYPE,
                         audience=CHANNEL_AUDIENCE),
                     consumer_group=args.consumer_group,
                     verify_cert_bundle=args.cert_bundle) as channel:
            # Produce the payload records to the channel
            channel.produce(channel_payload)

    print("Succeeded.")
def test_channel_auth(self):
    auth = ChannelAuth(self.url, self.username, self.password)

    req = MagicMock()
    req.headers = {}

    with patch("requests.get") as req_get:
        req_get.return_value = MagicMock()
        req_get.return_value.status_code = 200

        original_token = "1234567890"
        req_get.return_value.json = MagicMock(
            return_value={"AuthorizationToken": original_token})
        req = auth(req)
        self.assertIsNotNone(req)
        self.assertEqual(req.headers["Authorization"],
                         "Bearer {}".format(original_token))

        new_token = "ABCDEFGHIJ"
        req_get.return_value.json = MagicMock(
            return_value={"AuthorizationToken": new_token})

        # Even though the token that would be returned for a login attempt
        # has changed, the original token should be returned because it
        # was cached on the auth object.
        req = auth(req)
        self.assertIsNotNone(req)
        self.assertEqual(req.headers["Authorization"],
                         "Bearer {}".format(original_token))

        res = MagicMock()
        res.status_code = 403
        res.request.headers = {}

        with patch("requests.Session") as session:
            channel = Channel(self.url,
                              auth=auth,
                              consumer_group=self.consumer_group)

            create_403_mock = MagicMock()
            create_403_mock.status_code = 403

            create_200_mock = MagicMock()
            create_200_mock.status_code = 200
            create_200_mock.json = MagicMock(
                return_value={"consumerInstanceId": 1234})

            self.assertIsNone(channel._consumer_id)
            self.assertEqual(auth._token, original_token)

            session.return_value.request.side_effect = [
                create_403_mock, create_200_mock
            ]
            channel.create()
            self.assertEqual(channel._consumer_id, 1234)

            # The 403 returned from the channel create call above should
            # lead to a new token being issued for the next authentication
            # call.
            req = auth(req)
            self.assertIsNotNone(req)
            self.assertEqual(req.headers["Authorization"],
                             "Bearer {}".format(new_token))
            self.assertEqual(auth._token, new_token)
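# A minimal sketch (not part of the test above) of what the test exercises:
# ChannelAuth acts as a requests-compatible auth object that logs in once,
# caches the bearer token, and fetches a new token after the service rejects
# the cached one (as with the 403 in the test). URL and credentials are
# placeholders; the import path assumes the dxlstreamingclient package.
from dxlstreamingclient.channel import Channel, ChannelAuth

auth = ChannelAuth("http://localhost", "user", "password")
channel = Channel("http://localhost",
                  auth=auth,
                  consumer_group="my-consumer-group")
# The first request made through the channel triggers a login; the returned
# token is cached on the auth object and reused until the service rejects it.
channel.create()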
def test_run(self):
    auth = ChannelAuth(self.url, self.username, self.password)

    record_1_payload = {"testing": "record_1"}
    record_2_payload = {"testing": "record_2"}
    record_1 = create_record("topic1", record_1_payload,
                             partition=1, offset=1)
    record_2 = create_record("topic2", record_2_payload,
                             partition=1, offset=2)
    first_records_group = create_records([record_1, record_2])

    record_3_payload = {"testing": "record_3"}
    record_3 = create_record("topic3", record_3_payload,
                             partition=2, offset=3)
    second_records_group = create_records([record_3])

    third_records_group = create_records([])

    expected_payloads_received = [[record_1_payload, record_2_payload],
                                  [record_3_payload],
                                  []]

    expected_calls = [
        call("post",
             "http://localhost/databus/consumer-service/v1/consumers",
             json={
                 "consumerGroup": self.consumer_group,
                 "configs": {
                     "auto.offset.reset": "latest",
                     "enable.auto.commit": "false"
                 }
             }),
        call("post",
             "http://localhost/databus/consumer-service/v1/consumers/1234/subscription",
             json={"topics": ["topic1", "topic2", "topic3"]}),
        call("get",
             "http://localhost/databus/consumer-service/v1/consumers/1234/records"),
        call("post",
             "http://localhost/databus/consumer-service/v1/consumers/1234/offsets",
             json={
                 "offsets": [{
                     "topic": "topic1",
                     "partition": 1,
                     "offset": 1
                 }, {
                     "topic": "topic2",
                     "partition": 1,
                     "offset": 2
                 }]
             }),
        call("get",
             "http://localhost/databus/consumer-service/v1/consumers/1234/records"),
        call("post",
             "http://localhost/databus/consumer-service/v1/consumers/1234/offsets",
             json={
                 "offsets": [{
                     "topic": "topic3",
                     "partition": 2,
                     "offset": 3
                 }]
             }),
        call("get",
             "http://localhost/databus/consumer-service/v1/consumers/1234/records"),
        call("post",
             "http://localhost/databus/consumer-service/v1/consumers",
             json={
                 "consumerGroup": self.consumer_group,
                 "configs": {
                     "auto.offset.reset": "latest",
                     "enable.auto.commit": "false"
                 }
             }),
        call("post",
             "http://localhost/databus/consumer-service/v1/consumers/5678/subscription",
             json={"topics": ["topic1", "topic2", "topic3"]}),
        call("get",
             "http://localhost/databus/consumer-service/v1/consumers/5678/records")
    ]

    with patch("requests.Session") as session:
        session.return_value = MagicMock()  # self._session
        session.return_value.request = MagicMock()

        create_consumer_1_mock = MagicMock()
        create_consumer_1_mock.status_code = 200
        create_consumer_1_mock.json = MagicMock(
            return_value={"consumerInstanceId": 1234})

        subscr_mock = MagicMock()
        subscr_mock.status_code = 204

        consume_1_mock = MagicMock()
        consume_1_mock.status_code = 200
        consume_1_mock.json = MagicMock(return_value=first_records_group)

        consume_2_mock = MagicMock()
        consume_2_mock.status_code = 200
        consume_2_mock.json = MagicMock(return_value=second_records_group)

        consume_not_found_mock = MagicMock()
        consume_not_found_mock.status_code = 404

        create_consumer_2_mock = MagicMock()
        create_consumer_2_mock.status_code = 200
        create_consumer_2_mock.json = MagicMock(
            return_value={"consumerInstanceId": 5678})

        consume_3_mock = MagicMock()
        consume_3_mock.status_code = 200
        consume_3_mock.json = MagicMock(return_value=third_records_group)

        commit_mock = MagicMock()
        commit_mock.status_code = 204

        session.return_value.request.side_effect = [
            create_consumer_1_mock, subscr_mock, consume_1_mock,
            commit_mock, consume_2_mock, commit_mock,
            consume_not_found_mock, create_consumer_2_mock, subscr_mock,
            consume_3_mock, commit_mock
        ]

        channel = Channel(self.url,
                          auth=auth,
                          consumer_group=self.consumer_group,
                          retry_on_fail=True)

        payloads_received = []

        def on_consume(payloads):
            payloads_received.append(payloads)
            # Return True (continue consuming) only if at least one
            # payload dictionary was supplied in the payloads parameter.
            # Return False to terminate the run call when no additional
            # payloads are available to consume.
            return len(payloads) > 0

        session.return_value.request.reset_mock()
        channel.run(on_consume,
                    wait_between_queries=0,
                    topics=["topic1", "topic2", "topic3"])

        session.return_value.request.assert_has_calls(expected_calls)
        self.assertEqual(payloads_received, expected_payloads_received)
        self.assertEqual(len(payloads_received), 3)
def test_path_prefix(self):
    auth = ChannelAuth(self.url, self.username, self.password)
    with patch("requests.Session") as session:
        session.return_value = MagicMock()  # self._session
        session.return_value.request = MagicMock()

        create_mock = MagicMock()
        create_mock.status_code = 200
        create_mock.json = MagicMock(
            return_value={"consumerInstanceId": 1234})

        produce_mock = MagicMock()
        produce_mock.status_code = 204

        session.return_value.request.side_effect = [
            create_mock, produce_mock
        ]

        channel = Channel(self.url,
                          auth=auth,
                          consumer_group=self.consumer_group,
                          path_prefix="/base-path",
                          retry_on_fail=False)

        channel.create()
        session.return_value.request.assert_called_with(
            "post",
            "http://localhost/base-path/consumers",
            json={
                "consumerGroup": self.consumer_group,
                "configs": {
                    "auto.offset.reset": "latest",
                    "enable.auto.commit": "false"
                }
            })

        channel.produce({})
        session.return_value.request.assert_called_with(
            "post",
            "http://localhost/base-path/produce",
            json={},
            headers={"Content-Type": _PRODUCE_CONTENT_TYPE})

        session.return_value.request.reset_mock()
        session.return_value.request.side_effect = [
            create_mock, produce_mock
        ]

        channel = Channel(self.url,
                          auth=auth,
                          consumer_group=self.consumer_group,
                          consumer_path_prefix="/custom-consumer-path",
                          producer_path_prefix="/custom-producer-path",
                          retry_on_fail=False)

        channel.create()
        session.return_value.request.assert_called_with(
            "post",
            "http://localhost/custom-consumer-path/consumers",
            json={
                "consumerGroup": self.consumer_group,
                "configs": {
                    "auto.offset.reset": "latest",
                    "enable.auto.commit": "false"
                }
            })

        channel.produce({})
        session.return_value.request.assert_called_with(
            "post",
            "http://localhost/custom-producer-path/produce",
            json={},
            headers={"Content-Type": _PRODUCE_CONTENT_TYPE})
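# A minimal application-side sketch of the constructor options the test above
# exercises; URL, credentials, and prefixes are placeholders. 'path_prefix'
# overrides both the consumer and producer base paths at once, while
# 'consumer_path_prefix' and 'producer_path_prefix' set them independently.
from dxlstreamingclient.channel import Channel, ChannelAuth

channel = Channel(
    "http://localhost",
    auth=ChannelAuth("http://localhost", "user", "password"),
    consumer_group="my-consumer-group",
    consumer_path_prefix="/custom-consumer-path",  # default: /databus/consumer-service/v1
    producer_path_prefix="/custom-producer-path")  # default: /databus/cloudproxy/v1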
def test_main(self):
    auth = ChannelAuth(self.url, self.username, self.password)

    case_event = {
        "id": "a45a03de-5c3d-452a-8a37-f68be954e784",
        "entity": "case",
        "type": "creation",
        "tenant-id": "7af4746a-63be-45d8-9fb5-5f58bf909c25",
        "user": "******",
        "origin": "",
        "nature": "",
        "timestamp": "",
        "transaction-id": "",
        "case": {
            "id": "c00547df-6d74-4833-95ad-3a377c7274a6",
            "name": "A great case full of malware",
            "url": "https://mycaseserver.com/#/cases"
                   "/4e8e23f4-9fe9-4215-92c9-12c9672be9f1",
            "priority": "Low"
        }
    }

    with patch("requests.Session") as session:
        session.return_value = MagicMock()  # self._session
        session.return_value.request = MagicMock()

        create_mock = MagicMock()
        create_mock.status_code = 200
        create_mock.json = MagicMock(
            return_value={"consumerInstanceId": 1234})

        subscr_mock = MagicMock()
        subscr_mock.status_code = 204

        consum_mock = MagicMock()
        consum_mock.status_code = 200
        consum_mock.json = MagicMock(return_value=create_records([
            create_record("foo-topic", case_event, partition=1, offset=1)
        ]))

        commit_consumer_error_mock = MagicMock()
        commit_consumer_error_mock.status_code = 404

        commit_error_mock = MagicMock()
        commit_error_mock.status_code = 500

        commit_mock = MagicMock()
        commit_mock.status_code = 204

        produce_mock = MagicMock()
        produce_mock.status_code = 204

        delete_mock = MagicMock()
        delete_mock.status_code = 204

        delete_404_mock = MagicMock()
        delete_404_mock.status_code = 404

        delete_500_mock = MagicMock()
        delete_500_mock.status_code = 500

        session.return_value.request.side_effect = [
            create_mock, subscr_mock, consum_mock,
            commit_consumer_error_mock, commit_error_mock, commit_mock,
            produce_mock, delete_500_mock, delete_404_mock, delete_mock
        ]

        channel = Channel(self.url,
                          auth=auth,
                          consumer_group=self.consumer_group,
                          retry_on_fail=False,
                          verify_cert_bundle="cabundle.crt",
                          request_timeout=70,
                          session_timeout=60,
                          offset="earliest",
                          extra_configs={
                              "enable.auto.commit": "true",
                              "one.extra.setting": "one extra value",
                              "another.extra.setting": 42
                          })
        self.assertEqual(channel._session.verify, "cabundle.crt")

        channel.commit()  # forcing early exit due to no records to commit

        channel.create()
        session.return_value.request.assert_called_with(
            "post",
            "http://localhost/databus/consumer-service/v1/consumers",
            json={
                "consumerGroup": self.consumer_group,
                "configs": {
                    "request.timeout.ms": "70000",
                    "session.timeout.ms": "60000",
                    "enable.auto.commit": "true",
                    "auto.offset.reset": "earliest",
                    "one.extra.setting": "one extra value",
                    "another.extra.setting": 42
                }
            })

        channel.subscribe(["topic1", "topic2"])
        session.return_value.request.assert_called_with(
            "post",
            "http://localhost/databus/consumer-service/v1/consumers/1234/subscription",
            json={"topics": ["topic1", "topic2"]})

        records = channel.consume()
        self.assertEqual(records[0]["id"],
                         "a45a03de-5c3d-452a-8a37-f68be954e784")

        with self.assertRaises(ConsumerError):
            channel.commit()
        with self.assertRaises(TemporaryError):
            channel.commit()
        channel.commit()

        message_payload = {"detail": "Hello from OpenDXL"}
        produce_payload = {
            "records": [{
                "routingData": {
                    "topic": "topic1",
                    "shardingKey": ""
                },
                "message": {
                    "headers": {},
                    "payload": base64.b64encode(
                        json.dumps(message_payload).encode()).decode()
                }
            }]
        }
        channel.produce(produce_payload)
        session.return_value.request.assert_called_with(
            "post",
            "http://localhost/databus/cloudproxy/v1/produce",
            json=produce_payload,
            headers={"Content-Type": _PRODUCE_CONTENT_TYPE})

        with self.assertRaises(TemporaryError):
            channel.delete()  # trigger 500
        session.return_value.request.assert_called_with(
            "delete",
            "http://localhost/databus/consumer-service/v1/consumers/1234")

        session.return_value.request.reset_mock()
        channel.delete()  # trigger silent 404
        session.return_value.request.assert_called_with(
            "delete",
            "http://localhost/databus/consumer-service/v1/consumers/1234")

        session.return_value.request.reset_mock()
        channel._consumer_id = "1234"  # resetting consumer
        channel.delete()  # proper deletion
        session.return_value.request.assert_called_with(
            "delete",
            "http://localhost/databus/consumer-service/v1/consumers/1234")

        session.return_value.request.reset_mock()
        channel.delete()  # trigger early exit
message_payload = {"message": "Hello from OpenDXL"} # Create the full payload with records to produce to the channel channel_payload = { "records": [{ "routingData": { "topic": CHANNEL_TOPIC, "shardingKey": "" }, "message": { "headers": {}, # Convert the message payload from a dictionary to a # base64-encoded string. "payload": base64.b64encode(json.dumps(message_payload).encode()).decode() } }] } # Create a new channel object with Channel(CHANNEL_URL, auth=ChannelAuth(CHANNEL_URL, CHANNEL_USERNAME, CHANNEL_PASSWORD, verify_cert_bundle=VERIFY_CERTIFICATE_BUNDLE), verify_cert_bundle=VERIFY_CERTIFICATE_BUNDLE) as channel: # Produce the payload records to the channel channel.produce(channel_payload) print("Succeeded.")