def output_exists(config, props, service):
    """Determine if this service and destination combo has already been created

    Args:
        config (dict): The outputs config that has been read from disk
        props (OrderedDict): Contains various OutputProperty items
        service (str): The service for which the user is adding a configuration

    Returns:
        bool: True if the service/destination exists already
    """
    if service in config and props['descriptor'].value in config[service]:
        LOGGER_CLI.error('This descriptor is already configured for %s. '
                         'Please select a new and unique descriptor', service)
        return True

    return False
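# A minimal usage sketch for output_exists. OutputProperty lives elsewhere in
# the codebase; the namedtuple below is a hypothetical stand-in that only
# mimics the `.value` attribute the function reads.
from collections import OrderedDict, namedtuple

MockOutputProperty = namedtuple('MockOutputProperty', ['value'])

existing_config = {'slack': {'alerts-channel': 'url'}}
props = OrderedDict([('descriptor', MockOutputProperty(value='alerts-channel'))])
assert output_exists(existing_config, props, 'slack')  # logs an error, returns True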
def put_mocked_s3_object(bucket_name, key_name, body_value):
    """Create a mock AWS S3 object for testing

    Args:
        bucket_name (str): The bucket in which to place the object
        key_name (str): The key to use for the S3 object
        body_value (str): The actual value to use for the object
    """
    s3_resource = boto3.resource('s3', region_name='us-east-1')
    s3_resource.create_bucket(Bucket=bucket_name)
    obj = s3_resource.Object(bucket_name, key_name)
    response = obj.put(Body=body_value)

    # Log if this was not a success (this should not fail for mocked objects)
    if response['ResponseMetadata']['HTTPStatusCode'] != 200:
        LOGGER_CLI.error('Could not put mock object with key %s in s3 bucket with name %s',
                         key_name, bucket_name)
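# A usage sketch for put_mocked_s3_object, assuming the moto library is the
# mocking backend (its `mock_s3` decorator patches boto3 S3 calls in-process;
# the bucket, key, and body below are hypothetical).
from moto import mock_s3


@mock_s3
def test_put_mocked_s3_object():
    """Hypothetical test that seeds a mocked bucket and reads the object back"""
    put_mocked_s3_object('test-bucket', 'test-key', '{"creds": "secret"}')
    s3_object = boto3.resource('s3', region_name='us-east-1').Object('test-bucket', 'test-key')
    assert s3_object.get()['Body'].read() == '{"creds": "secret"}'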
def load_outputs_config(conf_dir='conf'):
    """Load the outputs configuration file from disk

    Args:
        conf_dir (str): Directory to read outputs config from

    Returns:
        dict: The output configuration settings
    """
    with open(os.path.join(conf_dir, OUTPUTS_CONFIG)) as outputs:
        try:
            values = json.load(outputs)
        except ValueError:
            LOGGER_CLI.error('The %s file could not be loaded into json',
                             OUTPUTS_CONFIG)
            raise

    return values
def save_api_creds_info(region, overwrite=False):
    """Function to add API creds information to parameter store

    Args:
        region (str): AWS region to use when saving to AWS Parameter Store
        overwrite (bool): If True, overwrite an existing parameter with this name
    """
    # Get all of the required credentials from the user for API calls
    required_creds = {
        'api_user': {
            'description': ('API username to retrieve IOCs via API calls. '
                            'This should be an email address.'),
            'format': re.compile(r'^[a-zA-Z].*@.*')
        },
        'api_key': {
            'description': ('API key to retrieve IOCs via API calls. '
                            'This should be a string of 40 alphanumeric characters.'),
            'format': re.compile(r'^[a-zA-Z0-9]{40}$')
        }
    }

    creds_dict = {auth_key: user_input(info['description'], False, info['format'])
                  for auth_key, info in required_creds.iteritems()}

    description = 'Required credentials for the Threat Intel Downloader'

    # Save these to the parameter store
    saved = save_parameter(region, ThreatStream.CRED_PARAMETER_NAME,
                           creds_dict, description, overwrite)
    if saved:
        LOGGER_CLI.info('Threat Intel Downloader credentials were successfully '
                        'saved to parameter store.')
    else:
        LOGGER_CLI.error('Threat Intel Downloader credentials were not saved to '
                         'parameter store.')

    return saved
def terraform_generate(**kwargs):
    """Generate all Terraform plans for the configured clusters.

    Keyword Args:
        config (dict): The loaded config from the 'conf/' directory
        init (bool): Indicates if main.tf is generated for `terraform init`

    Returns:
        bool: Result of cluster generating
    """
    config = kwargs.get('config')
    init = kwargs.get('init', False)

    # Setup main
    LOGGER_CLI.info('Generating cluster file: main.tf')
    main_json = json.dumps(generate_main(init=init, config=config),
                           indent=2,
                           sort_keys=True)
    with open('terraform/main.tf', 'w') as tf_file:
        tf_file.write(main_json)

    # Break out early during the init process, clusters aren't needed yet
    if init:
        return True

    # Setup clusters
    for cluster in config.clusters():
        if cluster == 'main':
            raise InvalidClusterName('Rename cluster "main" to something else!')

        LOGGER_CLI.info('Generating cluster file: %s.tf', cluster)
        cluster_dict = generate_cluster(cluster_name=cluster, config=config)
        if not cluster_dict:
            LOGGER_CLI.error('An error was generated while creating the %s cluster',
                             cluster)
            return False

        cluster_json = json.dumps(cluster_dict, indent=2, sort_keys=True)
        with open('terraform/{}.tf'.format(cluster), 'w') as tf_file:
            tf_file.write(cluster_json)

    return True
def _alarm_exists(self, alarm_name):
    """Check if this alarm name is already used somewhere. CloudWatch alarm
    names must be unique to an AWS account

    Args:
        alarm_name (str): The name of the alarm being created

    Returns:
        bool: True if the alarm name is already present in the config
    """
    message = ('CloudWatch metric alarm names must be unique '
               'within each AWS account. Please remove this alarm '
               'so it can be updated or choose another name.')

    funcs = {metrics.RULE_PROCESSOR_NAME}
    for func in funcs:
        for cluster in self.config['clusters']:
            func_alarms = (
                self.config['clusters'][cluster]['modules']['stream_alert'][func].get(
                    'metric_alarms', {}))
            if alarm_name in func_alarms:
                LOGGER_CLI.error('An alarm with name \'%s\' already exists in the '
                                 '\'conf/clusters/%s.json\' cluster. %s',
                                 alarm_name, cluster, message)
                return True

    global_config = self.config['global']['infrastructure'].get('monitoring')
    if not global_config:
        return False

    metric_alarms = global_config.get('metric_alarms')
    if not metric_alarms:
        return False

    # Check for functions saved in the global config.
    funcs.update({metrics.ALERT_PROCESSOR_NAME, metrics.ATHENA_PARTITION_REFRESH_NAME})

    for func in funcs:
        global_func_alarms = global_config['metric_alarms'].get(func, {})
        if alarm_name in global_func_alarms:
            LOGGER_CLI.error('An alarm with name \'%s\' already exists in the '
                             '\'conf/globals.json\'. %s', alarm_name, message)
            return True

    return False
def drop_all_tables(config):
    """Drop all 'streamalert' Athena tables

    Used when cleaning up an existing deployment

    Args:
        config (CLIConfig): Loaded StreamAlert CLI
    """
    if not continue_prompt(message='Are you sure you want to drop all Athena tables?'):
        return

    athena_client = get_athena_client(config)

    if not athena_client.drop_all_tables():
        LOGGER_CLI.error('Failed to drop one or more tables from database: %s',
                         athena_client.database)
    else:
        LOGGER_CLI.info('Successfully dropped all tables from database: %s',
                        athena_client.database)
def run_tests(options, context):
    """Actual protected function for running tests

    Args:
        options (namedtuple): CLI options (debug, processor, etc)
        context (namedtuple): A constructed aws context object
    """
    if options.debug:
        LOGGER_SA.setLevel(logging.DEBUG)
        LOGGER_SO.setLevel(logging.DEBUG)
        LOGGER_CLI.setLevel(logging.DEBUG)
    else:
        # Add a filter to suppress a few noisy log messages
        LOGGER_SA.addFilter(TestingSuppressFilter())

    # Check if the rule processor should be run for these tests
    test_rules = (set(options.processor).issubset({'rule', 'all'}) or
                  options.command == 'live-test')

    # Check if the alert processor should be run for these tests
    test_alerts = (set(options.processor).issubset({'alert', 'all'}) or
                   options.command == 'live-test')

    rule_proc_tester = RuleProcessorTester(context, test_rules)
    alert_proc_tester = AlertProcessorTester(context)
    # Run the rule processor for all rules or designated rule set
    for alerts in rule_proc_tester.test_processor(options.rules):
        # If the alert processor should be tested, process any alerts
        if test_alerts:
            alert_proc_tester.test_processor(alerts)

    # Report summary information for the alert processor if it was run
    if test_alerts:
        AlertProcessorTester.report_output_summary()

    # Print any invalid log messages that we accumulated over this run
    for message in rule_proc_tester.invalid_log_messages:
        LOGGER_CLI.error('%s%s%s', COLOR_RED, message, COLOR_RESET)

    if not (rule_proc_tester.all_tests_passed and
            alert_proc_tester.all_tests_passed and
            not rule_proc_tester.invalid_log_messages):
        sys.exit(1)
def kms_encrypt(region, data, kms_key_alias):
    """Encrypt data with AWS KMS.

    Args:
        region (str): AWS region to use for boto3 client
        data (str): json string to be encrypted
        kms_key_alias (str): The KMS key alias to use for encryption of S3 objects

    Returns:
        str: Encrypted ciphertext data blob
    """
    try:
        client = boto3.client('kms', region_name=region)
        response = client.encrypt(KeyId='alias/{}'.format(kms_key_alias),
                                  Plaintext=data)
        return response['CiphertextBlob']
    except ClientError:
        LOGGER_CLI.error('An error occurred during credential encryption')
        raise
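# A round-trip sketch pairing kms_encrypt with the corresponding KMS decrypt
# call. The region and key alias are hypothetical; KMS embeds the key ID in
# the ciphertext, so decrypt only needs the blob itself.
ciphertext = kms_encrypt('us-east-1', '{"url": "https://example.com/hook"}', 'my_secrets')
plaintext = boto3.client('kms', region_name='us-east-1').decrypt(
    CiphertextBlob=ciphertext)['Plaintext']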
def report_output_summary(self):
    """Helper function to print the summary results of all tests"""
    failure_messages = [item for item in self.status_messages
                        if item.type == StatusMessage.FAILURE]
    warning_messages = [item for item in self.status_messages
                        if item.type == StatusMessage.WARNING]
    passed_tests = self.total_tests - len(failure_messages)

    # Print some lines at the bottom of output to make it more readable
    # This occurs here so there is always space and not only when the
    # successful test info prints
    print '\n\n'

    # Only print success info if we explicitly want to print output
    # but always print any errors or warnings below
    if self.print_output:
        # Print a message indicating how many of the total tests passed
        LOGGER_CLI.info('%s(%d/%d) Successful Tests%s', COLOR_GREEN, passed_tests,
                        self.total_tests, COLOR_RESET)

    # Check if there were failed tests and report on them appropriately
    if failure_messages:
        # Print a message indicating how many of the total tests failed
        LOGGER_CLI.error('%s(%d/%d) Failures%s', COLOR_RED, len(failure_messages),
                         self.total_tests, COLOR_RESET)

        # Iterate over the rule_name values in the failed list and report on them
        for index, failure in enumerate(failure_messages, start=1):
            LOGGER_CLI.error('%s(%d/%d) [%s] %s%s', COLOR_RED, index,
                             len(failure_messages), failure.rule,
                             failure.message, COLOR_RESET)

    # Check if there were any warnings and report on them
    if warning_messages:
        warning_count = len(warning_messages)
        LOGGER_CLI.warn('%s%d Warning%s%s', COLOR_YELLOW, warning_count,
                        's' if warning_count > 1 else '', COLOR_RESET)

        for index, warning in enumerate(warning_messages, start=1):
            LOGGER_CLI.warn('%s(%d/%d) [%s] %s%s', COLOR_YELLOW, index, warning_count,
                            warning.rule, warning.message, COLOR_RESET)
def send_creds_to_s3(region, bucket, key, blob_data):
    """Put the encrypted credential blob for this service and destination in s3

    Args:
        region (str): AWS region to use for boto3 client
        bucket (str): The name of the s3 bucket to write the encrypted credentials to
        key (str): ID for the s3 object to write the encrypted credentials to
        blob_data (bytes): Cipher text blob from the kms encryption
    """
    try:
        client = boto3.client('s3', region_name=region)
        client.put_object(Body=blob_data, Bucket=bucket, Key=key)

        return True
    except ClientError as err:
        LOGGER_CLI.error('An error occurred while sending credentials to S3 for key \'%s\' '
                         'in bucket \'%s\': %s',
                         key, bucket, err.response['Error']['Message'])
        return False
def check_credentials():
    """Check for valid AWS credentials in environment variables

    Returns:
        bool: True if any of the AWS env variables exist
    """
    aws_env_variables = [
        'AWS_PROFILE',
        'AWS_SHARED_CREDENTIALS_FILE',
        'AWS_ACCESS_KEY_ID',
        'AWS_SECRET_ACCESS_KEY'
    ]

    env_vars_exist = any(env_var in os.environ for env_var in aws_env_variables)
    if not env_vars_exist:
        LOGGER_CLI.error('No valid AWS Credentials found in your environment!')
        LOGGER_CLI.error('Please follow the setup instructions here: '
                         'https://www.streamalert.io/account.html')
        return False

    return True
def check_credentials():
    """Check for valid AWS credentials by calling STS

    Returns:
        bool: True if boto3 can locate valid AWS credentials
    """
    try:
        response = boto3.client('sts').get_caller_identity()
    except NoCredentialsError:
        LOGGER_CLI.error('No valid AWS Credentials found in your environment!')
        LOGGER_CLI.error('Please follow the setup instructions here: '
                         'https://www.streamalert.io/getting-started.html'
                         '#configure-aws-credentials')
        return False

    LOGGER_CLI.debug('Using credentials for user \'%s\' with user ID \'%s\' in account '
                     '\'%s\'', response['Arn'], response['UserId'], response['Account'])

    return True
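# A minimal gating sketch for check_credentials, assuming a CLI entry point
# that should abort before doing any AWS work (hypothetical call site).
if not check_credentials():
    sys.exit(1)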
def set_prefix(self, prefix):
    """Set the Org Prefix in Global settings"""
    if not isinstance(prefix, (unicode, str)):
        LOGGER_CLI.error('Invalid prefix type, must be string')
        return

    self.config['global']['account']['prefix'] = prefix
    self.config['global']['terraform']['tfstate_bucket'] = self.config['global'][
        'terraform']['tfstate_bucket'].replace('PREFIX_GOES_HERE', prefix)

    self.config['lambda']['alert_processor_config']['source_bucket'] = self.config[
        'lambda']['alert_processor_config']['source_bucket'].replace(
            'PREFIX_GOES_HERE', prefix)
    self.config['lambda']['rule_processor_config']['source_bucket'] = self.config[
        'lambda']['rule_processor_config']['source_bucket'].replace(
            'PREFIX_GOES_HERE', prefix)
    self.write()

    LOGGER_CLI.info('Prefix successfully configured')
def run_command(runner_args, **kwargs):
    """Helper function to run commands with error handling.

    Args:
        runner_args (list): Commands to run via subprocess
        kwargs:
            cwd (str): A path to execute commands from
            error_message (str): Message to show if command fails
            quiet (bool): Whether to show command output or hide it
    """
    default_error_message = "An error occurred while running: {}".format(
        ' '.join(runner_args))
    error_message = kwargs.get('error_message', default_error_message)

    default_cwd = 'terraform'
    cwd = kwargs.get('cwd', default_cwd)

    # Add the -force-copy flag for s3 state copying to suppress dialogs that
    # the user must type 'yes' into.
    if runner_args[0] == 'terraform':
        if runner_args[1] == 'init':
            runner_args.append('-force-copy')

    stdout_option = None
    if kwargs.get('quiet'):
        stdout_option = open(os.devnull, 'w')

    try:
        subprocess.check_call(runner_args, stdout=stdout_option, cwd=cwd)  # nosec
    except subprocess.CalledProcessError as err:
        LOGGER_CLI.error('%s\n%s', error_message, err.cmd)
        return False
    except OSError as err:
        LOGGER_CLI.error('%s\n%s (%s)', error_message, err.strerror, runner_args[0])
        return False

    return True
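# A minimal usage sketch: run `terraform destroy` from the default
# 'terraform' working directory with output suppressed. The command and
# error message below are illustrative, not a prescribed invocation.
if not run_command(['terraform', 'destroy', '-force'],
                   error_message='Failed to destroy infrastructure',
                   quiet=True):
    sys.exit(1)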
def generate_s3_events(cluster_name, cluster_dict, config):
    """Add the S3 Events module to the Terraform cluster dict.

    Args:
        cluster_name (str): The name of the currently generating cluster
        cluster_dict (defaultdict): The dict containing all Terraform config for a given
                                    cluster.
        config (dict): The loaded config from the 'conf/' directory

    Returns:
        bool: Result of applying the s3_events module
    """
    modules = config['clusters'][cluster_name]['modules']
    s3_event_buckets = modules['s3_events']

    # Detect legacy and convert
    if isinstance(s3_event_buckets, dict) and 's3_bucket_id' in s3_event_buckets:
        del config['clusters'][cluster_name]['modules']['s3_events']
        s3_event_buckets = [{'bucket_id': s3_event_buckets['s3_bucket_id']}]
        config['clusters'][cluster_name]['modules']['s3_events'] = s3_event_buckets
        LOGGER_CLI.info('Converting legacy S3 Events config')
        config.write()

    for bucket_info in s3_event_buckets:
        if 'bucket_id' not in bucket_info:
            LOGGER_CLI.error('Config Error: Missing bucket_id key from s3_event configuration')
            return False

        cluster_dict['module']['s3_events_{}'.format(bucket_info['bucket_id'].replace(
            '.', '_'))] = {
                'source': 'modules/tf_stream_alert_s3_events',
                'lambda_function_arn': '${{module.stream_alert_{}.lambda_arn}}'.format(
                    cluster_name),
                'bucket_id': bucket_info['bucket_id'],
                'enable_events': bucket_info.get('enable_events', True),
                'lambda_role_id': '${{module.stream_alert_{}.lambda_role_id}}'.format(
                    cluster_name)
            }

    return True
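# A sketch of the two config shapes generate_s3_events accepts. The legacy
# form (a dict keyed by 's3_bucket_id') is rewritten in place to the list
# form; the bucket names below are hypothetical.
legacy_s3_events = {'s3_bucket_id': 'example-org.data'}
current_s3_events = [
    {'bucket_id': 'example-org.data', 'enable_events': True},
    {'bucket_id': 'example-org.more-data'},  # enable_events defaults to True
]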
def generate_cloudwatch_metric_alarms(cluster_name, cluster_dict, config):
    """Add the CloudWatch Metric Alarms information to the Terraform cluster dict.

    Args:
        cluster_name (str): The name of the currently generating cluster
        cluster_dict (defaultdict): The dict containing all Terraform config for a given
                                    cluster.
        config (dict): The loaded config from the 'conf/' directory
    """
    infrastructure_config = config['global'].get('infrastructure')
    if not (infrastructure_config and 'monitoring' in infrastructure_config):
        LOGGER_CLI.error('Invalid config: Make sure you declare global infrastructure options!')
        return

    sns_topic_arn = monitoring_topic_arn(config)

    cluster_dict['module']['stream_alert_{}'.format(cluster_name)]['sns_topic_arn'] = \
        sns_topic_arn

    stream_alert_config = config['clusters'][cluster_name]['modules']['stream_alert']

    # Add cluster metric alarms for the rule and alert processors
    formatted_alarms = []
    for func_config in stream_alert_config.values():
        if 'metric_alarms' not in func_config:
            continue

        # TODO: update this logic to simply use a list of maps once Terraform fixes
        # their support for this, instead of the comma-separated string this creates
        metric_alarms = func_config['metric_alarms']
        for name, alarm_info in metric_alarms.iteritems():
            formatted_alarms.append(_format_metric_alarm(name, alarm_info))

    cluster_dict['module']['stream_alert_{}'.format(cluster_name)]['metric_alarms'] = \
        formatted_alarms
def set_prefix(self, prefix):
    """Set the Org Prefix in Global settings"""
    if not isinstance(prefix, (unicode, str)):
        LOGGER_CLI.error('Invalid prefix type, must be string')
        return

    if '_' in prefix:
        LOGGER_CLI.error('Prefix cannot contain underscores')
        return

    tf_state_bucket = '{}.streamalert.terraform.state'.format(prefix)
    self.config['global']['account']['prefix'] = prefix
    self.config['global']['account']['kms_key_alias'] = '{}_streamalert_secrets'.format(
        prefix)
    self.config['global']['terraform']['tfstate_bucket'] = tf_state_bucket
    self.config['lambda']['athena_partition_refresh_config']['buckets'].clear()
    self.config['lambda']['athena_partition_refresh_config']['buckets'] \
        ['{}.streamalerts'.format(prefix)] = 'alerts'

    self.write()

    LOGGER_CLI.info('Prefix successfully configured')
def send_creds_to_s3(region, bucket, key, blob_data):
    """Put the encrypted credential blob for this service and destination in s3

    Args:
        region (str): AWS region to use for boto3 client
        bucket (str): The name of the s3 bucket to write the encrypted credentials to
        key (str): ID for the s3 object to write the encrypted credentials to
        blob_data (bytes): Cipher text blob from the kms encryption
    """
    try:
        client = boto3.client('s3', region_name=region)
        client.put_object(Body=blob_data,
                          Bucket=bucket,
                          Key=key,
                          ServerSideEncryption='AES256')

        return True
    except ClientError as err:
        LOGGER_CLI.error('An error occurred while sending credentials to S3 for key [%s]: '
                         '%s [%s]',
                         key,
                         err.response['Error']['Message'],
                         err.response['Error']['BucketName'])
        return False
def toggle_metrics(self, enabled, clusters, lambda_functions):
    """Toggle CloudWatch metric logging and filter creation

    Args:
        enabled (bool): True to enable metrics, False to disable them
        clusters (list): Clusters to enable or disable metrics on
        lambda_functions (list): Which lambda functions to enable or disable
            metrics on (rule, alert, or athena)
    """
    for function in lambda_functions:
        if function == metrics.ATHENA_PARTITION_REFRESH_NAME:
            if 'athena_partition_refresh_config' in self.config['lambda']:
                self.config['lambda']['athena_partition_refresh_config'] \
                    ['enable_metrics'] = enabled
            else:
                LOGGER_CLI.error('No Athena configuration found; please initialize first.')
            continue

        for cluster in clusters:
            self.config['clusters'][cluster]['modules']['stream_alert'] \
                [function]['enable_metrics'] = enabled

    self.write()
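# A usage sketch enabling metrics for the rule processor on a single cluster,
# assuming `config` is the loaded CLIConfig instance exposing this method;
# the 'prod' cluster name is hypothetical, and the function name comes from
# the metrics module constants.
config.toggle_metrics(True, ['prod'], [metrics.RULE_PROCESSOR_NAME])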
def save_app_auth_info(app, info, overwrite=False):
    """Function to add app auth information to parameter store

    Args:
        app: The app whose required_auth_info() defines the values to collect
        info (dict): Required values needed to save the requested authentication
            information to AWS Parameter Store
        overwrite (bool): If True, overwrite an existing parameter with this name
    """
    # Get all of the required authentication values from the user for this app integration
    auth_dict = {auth_key: user_input(auth_info['description'], False, auth_info['format'])
                 for auth_key, auth_info in app.required_auth_info().iteritems()}

    description = ('Required authentication information for the \'{}\' service for '
                   'use in the \'{}\' app'.format(info['type'], info['app_name']))

    # Save these to the parameter store
    param_name = '{}_{}'.format(info['function_name'], AppConfig.AUTH_CONFIG_SUFFIX)
    saved = save_parameter(info['region'], param_name, auth_dict, description, overwrite)
    if saved:
        LOGGER_CLI.info('App authentication info successfully saved to parameter store.')
    else:
        LOGGER_CLI.error('App authentication info was not saved to parameter store.')

    return saved
def user_input(requested_info, mask, input_restrictions):
    """Prompt user for requested information

    Args:
        requested_info (str): Description of the information needed
        mask (bool): Decides whether to mask input or not
        input_restrictions: Either a compiled regular expression the input must
            match, or a container of disallowed characters

    Returns:
        str: response provided by the user
    """
    # pylint: disable=protected-access
    response = ''
    prompt = '\nPlease supply {}: '.format(requested_info)

    if not mask:
        while not response:
            response = raw_input(prompt)

        # Restrict having spaces or colons in items (applies to things like
        # descriptors, etc)
        if isinstance(input_restrictions, re._pattern_type):
            if not input_restrictions.match(response):
                LOGGER_CLI.error('The supplied input should match the following '
                                 'regular expression: %s', input_restrictions.pattern)
                return user_input(requested_info, mask, input_restrictions)
        else:
            if any(x in input_restrictions for x in response):
                LOGGER_CLI.error('The supplied input should not contain any of the following: %s',
                                 '"{}"'.format('", "'.join(input_restrictions)))
                return user_input(requested_info, mask, input_restrictions)
    else:
        while not response:
            response = getpass(prompt=prompt)

    return response
def set_prefix(self, prefix):
    """Set the Org Prefix in Global settings"""
    if not isinstance(prefix, (unicode, str)):
        LOGGER_CLI.error('Invalid prefix type, must be string')
        return

    if '_' in prefix:
        LOGGER_CLI.error('Prefix cannot contain underscores')
        return

    tf_state_bucket = '{}.streamalert.terraform.state'.format(prefix)
    self.config['global']['account']['prefix'] = prefix
    self.config['global']['account']['kms_key_alias'] = '{}_streamalert_secrets'.format(prefix)
    self.config['global']['terraform']['tfstate_bucket'] = tf_state_bucket
    self.config['lambda']['athena_partition_refresh_config']['buckets'].clear()
    self.config['lambda']['athena_partition_refresh_config']['buckets'] \
        ['{}.streamalerts'.format(prefix)] = 'alerts'

    lambda_funcs = [
        'alert_merger',
        'alert_processor',
        'athena_partition_refresh',
        'rule_processor',
        'stream_alert_apps',
        'threat_intel_downloader'
    ]

    # Update all function configurations with the streamalert source bucket info
    source_bucket = '{}.streamalert.source'.format(prefix)
    for func in lambda_funcs:
        func_config = '{}_config'.format(func)
        if func_config in self.config['lambda']:
            self.config['lambda'][func_config]['source_bucket'] = source_bucket

    self.write()

    LOGGER_CLI.info('Prefix successfully configured')
def set_prefix(self, prefix):
    """Set the Org Prefix in Global settings"""
    if not isinstance(prefix, (unicode, str)):
        LOGGER_CLI.error('Invalid prefix type, must be string')
        return

    if '_' in prefix:
        LOGGER_CLI.error('Prefix cannot contain underscores')
        return

    self.config['global']['account']['prefix'] = prefix
    self.config['global']['terraform']['tfstate_bucket'] = self.config['global'][
        'terraform']['tfstate_bucket'].replace('PREFIX_GOES_HERE', prefix)

    self.config['lambda']['alert_processor_config']['source_bucket'] = self.config[
        'lambda']['alert_processor_config']['source_bucket'].replace(
            'PREFIX_GOES_HERE', prefix)
    self.config['lambda']['rule_processor_config']['source_bucket'] = self.config[
        'lambda']['rule_processor_config']['source_bucket'].replace(
            'PREFIX_GOES_HERE', prefix)

    if self.config['lambda'].get('stream_alert_apps_config'):
        self.config['lambda']['stream_alert_apps_config']['source_bucket'] = self.config[
            'lambda']['stream_alert_apps_config']['source_bucket'].replace(
                'PREFIX_GOES_HERE', prefix)

    if self.config['lambda'].get('threat_intel_downloader_config'):
        self.config['lambda']['threat_intel_downloader_config']['source_bucket'] = \
            self.config['lambda']['threat_intel_downloader_config']['source_bucket'].replace(
                'PREFIX_GOES_HERE', prefix)

    self.write()

    LOGGER_CLI.info('Prefix successfully configured')
def _extract_precompiled_libs(self, temp_package_path):
    """Extract any precompiled third-party packages into the deployment package folder

    Args:
        temp_package_path (str): Full path to temp package path

    Returns:
        bool: False if the required libs were not found, True otherwise
    """
    # Return true immediately if there are no precompiled requirements for this package
    if not self.precompiled_libs:
        return True

    # Get any dependency files throughout the package folders that have
    # the _dependencies.zip suffix
    dependency_files = {
        package_file: os.path.join(root, package_file)
        for folder in self.package_folders
        for root, _, package_files in os.walk(folder)
        for package_file in package_files
        if package_file.endswith('_dependencies.zip')
    }

    for lib in self.precompiled_libs:
        libs_name = '_'.join([lib, 'dependencies.zip'])
        if libs_name not in dependency_files:
            LOGGER_CLI.error('Missing precompiled libs for package: %s', libs_name)
            return False

        # Copy the contents of the dependency zip to the package directory
        with zipfile.ZipFile(dependency_files[libs_name], 'r') as libs_file:
            libs_file.extractall(temp_package_path)

    return True
def _extract_precompiled_libs(self, temp_package_path):
    """Extract any precompiled third-party packages into the deployment package folder

    Args:
        temp_package_path (str): Full path to temp package path

    Returns:
        bool: True if precompiled libs were extracted successfully, False if some are missing
    """
    dependency_files = {}  # Map library name to location of its precompiled .zip file
    for path in self.package_files:
        if path.endswith('_dependencies.zip'):
            dependency_files[os.path.basename(path)] = path
        elif os.path.isdir(path):
            # Traverse directory looking for .zip files
            for root, _, package_files in os.walk(path):
                dependency_files.update({
                    package_file: os.path.join(root, package_file)
                    for package_file in package_files
                    if package_file.endswith('_dependencies.zip')
                })

    for lib in self.precompiled_libs:
        libs_name = '_'.join([lib, 'dependencies.zip'])
        if libs_name not in dependency_files:
            LOGGER_CLI.error('Missing precompiled libs for package: %s', libs_name)
            return False

        # Copy the contents of the dependency zip to the package directory
        with zipfile.ZipFile(dependency_files[libs_name], 'r') as libs_file:
            libs_file.extractall(temp_package_path)

    return True
def user_input(requested_info, mask, input_restrictions):
    """Prompt user for requested information

    Args:
        requested_info (str): Description of the information needed
        mask (bool): Decides whether to mask input or not
        input_restrictions: Either a compiled regular expression the input must
            match, a callable that validates (and optionally transforms) the
            input, or a container of disallowed characters

    Returns:
        str: response provided by the user
    """
    # pylint: disable=protected-access
    response = ''
    prompt = '\nPlease supply {}: '.format(requested_info)

    if not mask:
        while not response:
            response = raw_input(prompt)

        # Restrict having spaces or colons in items (applies to things like
        # descriptors, etc)
        valid_response = False
        if isinstance(input_restrictions, re._pattern_type):
            valid_response = input_restrictions.match(response)
            if not valid_response:
                LOGGER_CLI.error('The supplied input should match the following '
                                 'regular expression: %s', input_restrictions.pattern)
        elif callable(input_restrictions):
            # Functions can be passed here to perform complex validation of input
            # Transform the response with the validating function
            response = input_restrictions(response)
            valid_response = response is not None and response is not False
            if not valid_response:
                LOGGER_CLI.error('The supplied input failed to pass the validation '
                                 'function: %s', input_restrictions.__doc__)
        else:
            valid_response = not any(x in input_restrictions for x in response)
            if not valid_response:
                restrictions = ', '.join('\'{}\''.format(restriction)
                                         for restriction in input_restrictions)
                LOGGER_CLI.error('The supplied input should not contain any of the following: %s',
                                 restrictions)

        if not valid_response:
            return user_input(requested_info, mask, input_restrictions)
    else:
        while not response:
            response = getpass(prompt=prompt)

    return response
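# Usage sketches for user_input's three restriction types: a compiled regex,
# a validating callable, and a container of disallowed characters. The
# prompts and the port validator below are hypothetical. Note that the
# callable's docstring is what gets logged on validation failure.
descriptor = user_input('a unique descriptor', False, re.compile(r'^[a-zA-Z0-9_-]+$'))


def _validate_port(value):
    """a numeric port between 1 and 65535"""
    try:
        port = int(value)
    except ValueError:
        return None
    return port if 0 < port < 65536 else None


port = user_input('a port to use', False, _validate_port)
channel = user_input('a channel name', False, {' ', ':'})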
def _validate_options(options):
    """Ensure the required CLI options were supplied

    Args:
        options (namedtuple): The parsed args passed from the CLI

    Returns:
        bool: True if all required options are present
    """
    if not options.interval:
        LOGGER_CLI.error('Missing command line argument --interval')
        return False

    if not options.timeout:
        LOGGER_CLI.error('Missing command line argument --timeout')
        return False

    if not options.memory:
        LOGGER_CLI.error('Missing command line argument --memory')
        return False

    return True
def rebuild_partitions(athena_client, options, config):
    """Rebuild an Athena table's partitions

    Steps:
      - Get the list of current partitions
      - Destroy existing table
      - Re-create tables
      - Re-create partitions

    Args:
        athena_client (boto3.client): Instantiated CLI AthenaClient
        options (namedtuple): The parsed args passed from the CLI
        config (CLIConfig): Loaded StreamAlert CLI
    """
    if not options.table_name:
        LOGGER_CLI.error('Missing command line argument --table_name')
        return

    if not options.bucket:
        LOGGER_CLI.error('Missing command line argument --bucket')
        return

    sa_firehose = StreamAlertFirehose(config['global']['account']['region'],
                                      config['global']['infrastructure']['firehose'],
                                      config['logs'])

    sanitized_table_name = sa_firehose.firehose_log_name(options.table_name)

    if options.type == 'data':
        # Get the current set of partitions
        partition_success, partitions = athena_client.run_athena_query(
            query='SHOW PARTITIONS {}'.format(sanitized_table_name), database='streamalert')
        if not partition_success:
            LOGGER_CLI.error('An error occurred when loading partitions for %s',
                             sanitized_table_name)
            return

        unique_partitions = athena_helpers.unique_values_from_query(partitions)

        # Drop the table
        LOGGER_CLI.info('Dropping table %s', sanitized_table_name)
        drop_success, _ = athena_client.run_athena_query(
            query='DROP TABLE {}'.format(sanitized_table_name), database='streamalert')
        if not drop_success:
            LOGGER_CLI.error('An error occurred when dropping the %s table',
                             sanitized_table_name)
            return

        LOGGER_CLI.info('Dropped table %s', sanitized_table_name)

        new_partitions_statement = athena_helpers.partition_statement(
            unique_partitions, options.bucket, sanitized_table_name)

        # Make sure our new alter table statement is within the query API limits
        if len(new_partitions_statement) > MAX_QUERY_LENGTH:
            LOGGER_CLI.error('Partition statement too large, writing to local file')
            with open('partitions_{}.txt'.format(sanitized_table_name), 'w') as partition_file:
                partition_file.write(new_partitions_statement)
            return

        # Re-create the table with previous partitions
        options.refresh_type = 'add_hive_partition'
        create_table(athena_client, options, config)

        LOGGER_CLI.info('Creating %d new partitions for %s',
                        len(unique_partitions), sanitized_table_name)
        new_part_success, _ = athena_client.run_athena_query(
            query=new_partitions_statement, database='streamalert')
        if not new_part_success:
            LOGGER_CLI.error('Error re-creating new partitions for %s',
                             sanitized_table_name)
            return

        LOGGER_CLI.info('Successfully rebuilt partitions for %s', sanitized_table_name)

    else:
        LOGGER_CLI.info('Refreshing alerts tables unsupported')
def create_table(athena_client, options, config):
    """Create a 'streamalert' Athena table

    Args:
        athena_client (boto3.client): Instantiated CLI AthenaClient
        options (namedtuple): The parsed args passed from the CLI
        config (CLIConfig): Loaded StreamAlert CLI
    """
    sa_firehose = StreamAlertFirehose(config['global']['account']['region'],
                                      config['global']['infrastructure']['firehose'],
                                      config['logs'])

    if not options.bucket:
        LOGGER_CLI.error('Missing command line argument --bucket')
        return

    if not options.refresh_type:
        LOGGER_CLI.error('Missing command line argument --refresh_type')
        return

    if options.type == 'data':
        if not options.table_name:
            LOGGER_CLI.error('Missing command line argument --table_name')
            return

        # Convert special characters in schema name to underscores
        sanitized_table_name = sa_firehose.firehose_log_name(options.table_name)

        # Check that the log type is enabled via Firehose
        if sanitized_table_name not in sa_firehose.enabled_logs:
            LOGGER_CLI.error('Table name %s missing from configuration or '
                             'is not enabled.', sanitized_table_name)
            return

        # Check if the table exists
        if athena_client.check_table_exists(sanitized_table_name):
            LOGGER_CLI.info('The \'%s\' table already exists.', sanitized_table_name)
            return

        log_info = config['logs'][options.table_name.replace('_', ':', 1)]

        schema = dict(log_info['schema'])
        sanitized_schema = StreamAlertFirehose.sanitize_keys(schema)

        athena_schema = handler_helpers.to_athena_schema(sanitized_schema)

        # Add envelope keys to Athena Schema
        configuration_options = log_info.get('configuration')
        if configuration_options:
            envelope_keys = configuration_options.get('envelope_keys')
            if envelope_keys:
                sanitized_envelope_key_schema = StreamAlertFirehose.sanitize_keys(envelope_keys)
                # Note: this key is wrapped in backticks to be Hive compliant
                athena_schema['`streamalert:envelope_keys`'] = handler_helpers.to_athena_schema(
                    sanitized_envelope_key_schema)

        # Handle Schema overrides
        # This is useful when an Athena schema needs to differ from the normal log schema
        if options.schema_override:
            for override in options.schema_override:
                if '=' not in override:
                    LOGGER_CLI.error('Invalid schema override [%s], use column_name=type format',
                                     override)
                    return

                column_name, column_type = override.split('=')
                if not all([column_name, column_type]):
                    LOGGER_CLI.error('Invalid schema override [%s], use column_name=type format',
                                     override)
                    return

                # Columns are escaped to avoid Hive issues with special characters
                column_name = '`{}`'.format(column_name)
                if column_name in athena_schema:
                    athena_schema[column_name] = column_type
                    LOGGER_CLI.info('Applied schema override: %s:%s', column_name, column_type)
                else:
                    LOGGER_CLI.error('Schema override column %s not found in Athena Schema, '
                                     'skipping', column_name)

        query = _construct_create_table_statement(
            schema=athena_schema, table_name=sanitized_table_name, bucket=options.bucket)

    elif options.type == 'alerts':
        if athena_client.check_table_exists(options.type):
            LOGGER_CLI.info('The \'alerts\' table already exists.')
            return

        query = ALERTS_TABLE_STATEMENT.format(bucket=options.bucket)

    if query:
        create_table_success, _ = athena_client.run_athena_query(
            query=query, database='streamalert')

        if create_table_success:
            # Update the CLI config
            config['lambda']['athena_partition_refresh_config'] \
                ['refresh_type'][options.refresh_type][options.bucket] = options.type
            config.write()

            table_name = options.type if options.type == 'alerts' else sanitized_table_name
            LOGGER_CLI.info('The %s table was successfully created!', table_name)