예제 #1
0
def main(override_args=None):
    """Print or save the query hash associated with a query body.

    The query body is read from a JSON file (the body used for the Advanced
    Search) and sent through ``AdvancedSearch.get_threats``. The ``query_hash``
    field of the response is then written to ``--output`` or logged.

    :param override_args: optional argument list parsed instead of the command line
    """
    scripts = BaseScripts()
    logger.debug(f'START: get_query_hash.py')

    # Load initial args
    parser = scripts.start('Retrieve a query hash from a query body (a json used for the Advanced Search).')
    parser.add_argument_group('required arguments').add_argument(
        'query_body_path',
        help='path to the json file containing the query body',
    )
    args = parser.parse_args(override_args) if override_args else parser.parse_args()

    # Load api_endpoints and tokens
    endpoint_config, _main_url, tokens = scripts.load_config(args)
    with open(args.query_body_path, 'r') as body_file:
        query_body = json.load(body_file)
    logger.debug(f'Retrieving query hash for query body: {query_body}')

    search_engine = AdvancedSearch(endpoint_config, args.env, tokens)
    # NOTE(review): limit=0 presumably requests no threats; only the query hash is used here
    response = search_engine.get_threats(query_body, limit=0)
    if not (response and 'query_hash' in response):
        logger.error("Couldn't retrieve a query hash, is the query body valid ?")
        exit(1)
    query_hash = response['query_hash']

    if not args.output:
        logger.info(f'Query hash associated: {query_hash}')
        return
    with open(args.output, 'w') as output_file:
        output_file.write(query_hash)
    logger.info(f'Query hash saved in {args.output}')
예제 #2
0
    def _handle_bulk_search_task(self, task_uuid):
        """
        Poll the retrieve-bulk-search endpoint until the bulk search result is available.

        Non-200/401 responses are retried with an exponential back-off (starting at
        10 seconds, capped by ``OCD_DTL_MAX_BACK_OFF_TIME``) until
        ``OCD_DTL_MAX_BULK_SEARCH_TIME`` seconds have elapsed. A 401 triggers a token
        update followed by an immediate retry.

        :param task_uuid: uuid of the bulk search task, substituted into the endpoint url
        :return: the JSON body of the first HTTP 200 response whose body is non-empty
        :raises TimeoutError: if no result is available before the maximum search time
        """
        retrieve_bulk_result_url = self._build_url_for_endpoint(
            'retrieve-bulk-search').format(task_uuid=task_uuid)

        start_time = time()
        back_off_time = 10  # seconds, doubled after each unsuccessful poll

        json_response = None
        while not json_response:
            response = requests.get(url=retrieve_bulk_result_url,
                                    headers={'Authorization': self.tokens[0]})
            if response.status_code == 200:
                json_response = response.json()
            elif response.status_code == 401:
                logger.debug('Refreshing expired Token')
                self._token_update(response.json())
            elif time() - start_time + back_off_time < self.OCD_DTL_MAX_BULK_SEARCH_TIME:
                sleep(back_off_time)
                back_off_time = min(back_off_time * 2,
                                    self.OCD_DTL_MAX_BACK_OFF_TIME)
            else:
                message = (
                    f'No bulk search result after waiting {self.OCD_DTL_MAX_BULK_SEARCH_TIME / 60:.0f} mins\n'
                    f'task_uuid: "{task_uuid}"')
                # logger.error() requires a message: calling it bare raised a TypeError
                # that masked the TimeoutError below.
                logger.error(message)
                raise TimeoutError(message)

        return json_response
예제 #3
0
    def handle_bulk_task(self, task_uuid, retrieve_bulk_result_url, *, timeout, additional_checks: List[Check] = None) \
            -> Json:
        """
        Handle a generic bulk task, blocking until the task is done or the timeout is up

        An HTTP 200 response is accepted only if its JSON body is non-empty and passes every
        additional check. Otherwise the task is considered not ready yet, and polling goes
        through the same exponential back-off and timeout as any other non-200 response.
        A 401 response triggers a token update followed by an immediate retry.

        :param task_uuid: uuid of the bulk task
        :param retrieve_bulk_result_url: endpoint to query, must contain a {task_uuid} placeholder
        :param timeout: timeout (in seconds) after which a TimeoutError is raised
        :param additional_checks: functions to call on a potential json, if all checks return True, the Json is returned
        :return: a Json returned on HTTP 200 validating all additional_checks
        :raises TimeoutError: if no valid response was obtained before the timeout
        """
        retrieve_bulk_result_url = retrieve_bulk_result_url.format(task_uuid=task_uuid)

        spinner = None
        if logger.isEnabledFor(logging.INFO):
            spinner = Halo(text=f'Waiting for bulk task {task_uuid} response', spinner='dots')
            spinner.start()

        start_time = time()
        back_off_time = 10  # seconds, doubled after each unsuccessful poll

        while True:
            response = requests.get(
                url=retrieve_bulk_result_url,
                headers={'Authorization': self.tokens[0]},
                verify=self.requests_ssl_verify
            )
            if response.status_code == 200:
                potential_json_response = response.json()
                if potential_json_response and (
                        not additional_checks
                        or all(check(potential_json_response) for check in additional_checks)):
                    if spinner:
                        spinner.succeed(f'bulk task {task_uuid} done')
                        spinner.stop()
                    return potential_json_response
                # The json isn't valid yet: fall through to the back-off below. Re-polling
                # immediately would busy-loop against the API and never reach the timeout.
            elif response.status_code == 401:
                logger.debug('Refreshing expired Token')
                self._token_update(response.json())
                continue  # retry right away with the refreshed token

            if time() - start_time + back_off_time < timeout:
                sleep(back_off_time)
                back_off_time = min(back_off_time * 2, self.OCD_DTL_MAX_BACK_OFF_TIME)
            else:
                if spinner:
                    spinner.fail(f'bulk task {task_uuid} timeout')
                message = (
                    f'No bulk result after waiting {timeout / 60:.0f} mins\n'
                    f'task_uuid: "{task_uuid}"'
                )
                # logger.error() requires a message: calling it bare raised a TypeError
                # that masked the TimeoutError below.
                logger.error(message)
                raise TimeoutError(message)
예제 #4
0
    def get_threats(self, query_hash: str = None, query_body: BaseEngine.Json = None, query_fields: List[str] = None) \
            -> dict:
        """
        Create a bulk search and wait for its result.

        The search is defined by ``query_body`` when it is truthy, otherwise by ``query_hash``.

        :param query_hash: hash of an existing query, used when no query_body is given
        :param query_body: query body, expanded with ``build_full_query_body`` before sending
        :param query_fields: fields to retrieve for each threat
        :return: the bulk search result, or an empty dict if the bulk search couldn't be created
        """
        if query_body:
            query_key, query_value = 'query_body', self.build_full_query_body(query_body)
        else:
            query_key, query_value = 'query_hash', query_hash
        body = {"query_fields": query_fields, query_key: query_value}

        creation = self.datalake_requests(self.url, 'post', post_body=body, headers=self._post_headers())
        if creation:
            return self._handle_bulk_search_task(task_uuid=creation['task_uuid'])
        logger.error('No bulk search created, is the query_hash valid as well as the query_fields ?')
        return {}
예제 #5
0
 def get_threats(self,
                 query_hash: str,
                 query_fields: List[str] = None) -> dict:
     """
     Create a bulk search from a query hash and wait for its result.

     :param query_hash: hash of the query to run
     :param query_fields: fields to retrieve for each threat
     :return: the bulk search result, or an empty dict if the bulk search couldn't be created
     """
     request_body = dict(query_hash=query_hash, query_fields=query_fields)
     created_task = self.datalake_requests(self.url,
                                           'post',
                                           post_body=request_body,
                                           headers=self._post_headers())
     if not created_task:
         logger.error(
             'No bulk search created, is the query_hash valid as well as the query_fields ?'
         )
         return {}
     task_uuid = created_task['task_uuid']
     return self._handle_bulk_search_task(task_uuid=task_uuid)
예제 #6
0
    def load_config(self, args):
        """
        Load correct config and generate first tokens

        Configures logging from ``args.loglevel``, loads the endpoint config and
        generates an access/refresh token pair for ``args.env``.

        :param args: parsed command line arguments (``loglevel`` and ``env`` are read)
        :return (dict, str, list<str, str>): endpoint config, main url of the environment
            and ``['Token <access_token>', 'Token <refresh_token>']``
        Exits the process with status 1 if no token could be generated.
        """
        configure_logging(args.loglevel)
        endpoint_config = self._load_config(self._CONFIG_ENDPOINTS)
        main_url = endpoint_config['main'][args.env]
        token_generator = TokenGenerator(endpoint_config, environment=args.env)
        token_json = token_generator.get_token()
        if not token_json:
            logger.error("Couldn't generate Tokens, please check the login/password provided")
            # A bare exit() reported success (status 0) on an authentication failure.
            exit(1)
        tokens = [f'Token {token_json["access_token"]}', f'Token {token_json["refresh_token"]}']
        return endpoint_config, main_url, tokens
예제 #7
0
    def get_token(self):
        """
        Generate token from user input, with email and password

        Credentials are read from the OCD_DTL_USERNAME / OCD_DTL_PASSWORD environment
        variables, falling back to interactive prompts.

        :return: the decoded JSON response of the token endpoint (contains ``access_token``)
        Exits the process with status 1 if the API doesn't return an access token,
        including when its response isn't a JSON object.
        """
        username = os.getenv('OCD_DTL_USERNAME') or input('Email: ')
        password = os.getenv('OCD_DTL_PASSWORD') or getpass()
        print()
        data = {'email': username, 'password': password}

        response = requests.post(url=self.url_token, json=data, verify=self.requests_ssl_verify)
        try:
            json_response = json.loads(response.text)
        except ValueError:  # json.JSONDecodeError, e.g. an HTML error page instead of JSON
            json_response = None
        if isinstance(json_response, dict) and 'access_token' in json_response:
            return json_response
        # else an error occurred

        logger.error(f'An error occurred while retrieving an access token, for URL: {self.url_token}\n'
                     f'response of the API: {response.text}')
        exit(1)
예제 #8
0
def get_atom_type_from_filename(filename, input_delimiter=':'):
    """
    Split a filename into its atom type and the actual file path.

    ``'type:path'`` gives ``['type', 'path']`` when ``type`` is one of ATOM_TYPES_FLAGS.
    A filename without the delimiter gives ``[UNTYPED_ATOM_TYPE, filename]``.
    Any other shape is logged as an error and the process exits with status 1.
    """
    pieces = filename.split(input_delimiter, 1)

    # untyped files: no delimiter, the whole string is the path
    if len(pieces) == 1:
        return [UNTYPED_ATOM_TYPE, pieces[0]]

    # typed files: a known atom type flag before the first delimiter
    is_typed = len(pieces) == 2 and pieces[0] in ATOM_TYPES_FLAGS
    if is_typed:
        return pieces

    logger.error(
        f'{filename} filename could not be treated `atomtype:path/to/file.txt`'
    )
    exit(1)
예제 #9
0
    def refresh_token(self, refresh_token: str):
        """
        Refresh the current token

        If the API reports that the refresh token itself has expired, a full
        authentication is done again through ``get_token``.

        :param refresh_token: str, sent as the Authorization header
        :return: the decoded JSON response containing a new ``access_token``
        Exits the process with status 1 on any other response, including when the
        response isn't a JSON object.
        """
        logger.debug('Token will be refresh')
        headers = {'Authorization': refresh_token}
        response = requests.post(url=self.url_refresh, headers=headers, verify=self.requests_ssl_verify)

        try:
            json_response = json.loads(response.text)
        except ValueError:  # json.JSONDecodeError, e.g. an HTML error page instead of JSON
            json_response = None
        if not isinstance(json_response, dict):
            json_response = {}  # let the error path below report the raw response

        if response.status_code == 401 and json_response.get('msg') == 'Token has expired':
            logger.debug('Refreshing the refresh token')
            # Refresh token is also expired, we need to restart the authentication from scratch
            return self.get_token()
        elif 'access_token' in json_response:
            return json_response
        # else an error occurred

        logger.error(f'An error occurred while refreshing the refresh token, for URL: {self.url_refresh}\n'
                     f'response of the API: {response.text}')
        exit(1)
예제 #10
0
    def datalake_requests(self,
                          url: str,
                          method: str,
                          headers: dict,
                          post_body: dict = None):
        """
        Use it to request the API

        Sends the request through ``self._send_request`` and retries until a 2xx
        response is decoded or ``self.SET_MAX_RETRY`` attempts have failed (at least
        one attempt is always made). Retries happen immediately, with no delay:
        - 401 and 422 responses trigger ``self._token_update`` before retrying.
        - Other non-2xx responses are retried as-is.
        - 2xx responses that ``self._load_response`` fails to decode
          (JSONDecodeError) are retried as-is.

        ``headers`` is stored as ``self.headers`` (same dict object) and is mutated
        in place: its 'Authorization' entry is filled with a fresh access token
        when missing.

        :param url: URL to request
        :param method: HTTP method, passed to ``_send_request`` (e.g. 'post')
        :param headers: request headers
        :param post_body: optional request body
        :return: the value of ``self._load_response`` for the first successful 2xx
                 response, or ``{}`` once all tries are used up
        """
        # NOTE(review): presumably _token_update updates self.tokens / self.headers so the
        # aliased ``headers`` dict used by the retry loop carries the new token — confirm.
        self.headers = headers
        tries_left = self.SET_MAX_RETRY

        logger.debug(
            self._pretty_debug_request(url, method, post_body, headers,
                                       self.tokens))

        # No Authorization header: authenticate from scratch and store both tokens
        if not headers.get('Authorization'):
            fresh_tokens = self.token_generator.get_token()
            self.tokens = [
                f'Token {fresh_tokens["access_token"]}',
                f'Token {fresh_tokens["refresh_token"]}'
            ]
            headers['Authorization'] = self.tokens[0]
        while True:
            response = self._send_request(url, method, headers, post_body)
            logger.debug(f'API response:\n{str(response.text)}')
            if response.status_code == 401:
                logger.warning(
                    'Token expired or Missing authorization header. Updating token'
                )
                self._token_update(self._load_response(response))
            elif response.status_code == 422:
                logger.warning('Bad authorization header. Updating token')
                logger.debug(f'422 HTTP code: {response.text}')
                self._token_update(self._load_response(response))
            elif response.status_code < 200 or response.status_code > 299:
                logger.error(
                    f'API returned non 2xx response code : {response.status_code}\n{response.text}'
                    f'\n Retrying')
            else:
                try:
                    dict_response = self._load_response(response)
                    return dict_response
                except JSONDecodeError:
                    logger.error(
                        'Request unexpectedly returned non dict value. Retrying'
                    )
            # Every unsuccessful attempt, token updates included, consumes one try
            tries_left -= 1
            if tries_left <= 0:
                logger.error(
                    'Request failed: Will return nothing for this request')
                return {}
def main(override_args=None):
    """Method to start the script

    Retrieve the threats matching a query hash, or a query body read from a JSON
    file, through a bulk search. The result is logged or saved to ``--output``.

    :param override_args: optional argument list parsed instead of the command line
    """
    starter = BaseScripts()
    logger.debug('START: get_threats_from_query_hash.py')

    # Load initial args
    parser = starter.start(
        'Retrieve a list of response from a given query hash.')
    parser.add_argument(
        '--query_fields',
        help=
        'fields to be retrieved from the threat (default: only the hashkey)\n'
        'If an atom detail isn\'t present in a particular atom, empty string is returned.',
        nargs='+',
        default=['threat_hashkey'],
    )
    parser.add_argument(
        '--list',
        help=
        'Turn the output in a list (require query_fields to be a single element)',
        action='store_true',
    )
    required_named = parser.add_argument_group('required arguments')
    required_named.add_argument(
        'query_hash',
        help=
        'the query hash from which to retrieve the response hashkeys or a path to the query body json file',
    )
    if override_args:
        args = parser.parse_args(override_args)
    else:
        args = parser.parse_args()

    if len(args.query_fields) > 1 and args.list:
        parser.error(
            "List output format is only available if a single element is queried (via query_fields)"
        )

    query_body = {}
    query_hash = args.query_hash
    # Anything that isn't 32 characters long, or that is an existing path, is read as a query body file
    if len(query_hash) != 32 or os.path.exists(query_hash):
        try:
            with open(query_hash, 'r') as query_body_file:
                query_body = json.load(query_body_file)
        except (OSError, ValueError):
            # OSError: missing/unreadable path, directory, name too long...
            # ValueError: json.JSONDecodeError for a file that isn't valid JSON
            logger.error(
                f"Couldn't understand the given value as a query hash or path to query body: {query_hash}"
            )
            exit(1)

    # Load api_endpoints and tokens
    endpoint_config, _main_url, tokens = starter.load_config(args)
    logger.debug(
        f'Start to search for threat from the query hash:{query_hash}')

    bulk_search = BulkSearch(endpoint_config, args.env, tokens)
    if query_body:
        response = bulk_search.get_threats(query_body=query_body,
                                           query_fields=args.query_fields)
    else:
        response = bulk_search.get_threats(query_hash=query_hash,
                                           query_fields=args.query_fields)
    original_count = response.get('count', 0)
    logger.info(f'Number of threat that have been retrieved: {original_count}')

    formatted_output = format_output(response, args.list)
    if args.output:
        with open(args.output, 'w') as output:
            output.write(formatted_output)
        logger.info(f'Threats saved in {args.output}')
    else:
        logger.info(formatted_output)
        logger.info('Done')
예제 #12
0
def main(override_args=None):
    """Method to start the script

    Submit the threats listed in a file (plain list or one CSV column) to Datalake.
    They are sent in bulk by default, or with one API call per threat with
    ``--no-bulk``. The API result is optionally saved to ``--output``.

    :param override_args: optional argument list parsed instead of the command line
    """
    starter = BaseScripts()

    # Load initial args
    parser = starter.start('Submit a new threat to Datalake from a file')
    required_named = parser.add_argument_group('required arguments')
    csv_control = parser.add_argument_group('CSV control arguments')
    required_named.add_argument(
        '-i',
        '--input',
        help='read threats to add from FILE',
        required=True,
    )
    required_named.add_argument(
        '-a',
        '--atom_type',
        help='set it to define the atom type',
        required=True,
    )
    csv_control.add_argument(
        '--is_csv',
        help='set if the file input is a CSV',
        action='store_true',
    )
    csv_control.add_argument(
        '-d',
        '--delimiter',
        help='set the delimiter of the CSV file',
        default=',',
    )
    csv_control.add_argument(
        '-c',
        '--column',
        help='select column of the CSV file, starting at 1',
        type=int,
        default=1,
    )
    parser.add_argument(
        '-p',
        '--public',
        help='set the visibility to public',
        action='store_true',
    )
    parser.add_argument(
        '-w',
        '--whitelist',
        help='set it to define the added threats as whitelist',
        action='store_true',
    )
    parser.add_argument(
        '-t',
        '--threat_types',
        nargs='+',
        help=
        'choose specific threat types and their score, like: ddos 50 scam 15',
        default=[],
    )
    parser.add_argument(
        '--tag',
        nargs='+',
        help='add a list of tags',
        default=[],
    )
    parser.add_argument(
        '--link',
        help='add link as external_analysis_link',
        nargs='+',
    )
    parser.add_argument(
        '--permanent',
        help=
        'sets override_type to permanent. Scores won\'t be updated by the algorithm. Default is temporary',
        action='store_true',
    )
    parser.add_argument(
        '--no-bulk',
        help=
        'force an api call for each threats, useful to retrieve the details of threats created',
        action='store_true',
    )
    if override_args:
        args = parser.parse_args(override_args)
    else:
        args = parser.parse_args()
    logger.debug('START: add_new_threats.py')

    if not args.threat_types and not args.whitelist:
        parser.error(
            "threat types is required if the atom is not for whitelisting")

    permanent = 'permanent' if args.permanent else 'temporary'

    if args.is_csv:
        try:
            # --column is 1-based, _load_csv expects a 0-based index
            list_new_threats = starter._load_csv(args.input, args.delimiter,
                                                 args.column - 1)
        except ValueError as ve:
            logger.error(ve)
            # A bare exit() reported success (status 0) on an unreadable CSV.
            exit(1)
    else:
        list_new_threats = starter._load_list(args.input)
    list_new_threats = defang_threats(list_new_threats, args.atom_type)
    # removing duplicates while preserving order
    list_new_threats = list(OrderedDict.fromkeys(list_new_threats))
    threat_types = ThreatsPost.parse_threat_types(args.threat_types) or []

    # Load api_endpoints and tokens
    endpoint_config, _main_url, tokens = starter.load_config(args)
    if args.no_bulk:
        post_engine_add_threats = ThreatsPost(endpoint_config, args.env,
                                              tokens)
        response_dict = post_engine_add_threats.add_threats(
            list_new_threats, args.atom_type, args.whitelist, threat_types,
            args.public, args.tag, args.link, permanent)
    else:
        post_engine_add_threats = BulkThreatsPost(endpoint_config, args.env,
                                                  tokens)
        hashkeys = post_engine_add_threats.add_bulk_threats(
            list_new_threats, args.atom_type, args.whitelist, threat_types,
            args.public, args.tag, args.link, permanent)
        # NOTE(review): 'haskeys' looks like a typo of 'hashkeys'; kept as-is because it is
        # part of the saved output format that consumers may rely on.
        response_dict = {'haskeys': list(hashkeys)}

    if args.output:
        starter.save_output(args.output, response_dict)
        logger.debug(f'Results saved in {args.output}\n')
    logger.debug('END: add_new_threats.py')
예제 #13
0
def main(override_args=None):
    """Method to start the script

    Look up threats given on the command line and/or read from a file (plain list
    or one CSV column). The result is optionally saved to ``--output``.

    :param override_args: optional argument list parsed instead of the command line
    """
    starter = BaseScripts()

    # Load initial args
    # The description and --input help were copy-pasted from add_new_threats.py
    # and described submitting threats instead of looking them up.
    parser = starter.start('Lookup threats in Datalake')
    required_named = parser.add_argument_group('required arguments')
    csv_control = parser.add_argument_group('CSV control arguments')

    parser.add_argument(
        'threats',
        help='threats to lookup',
        nargs='*',
    )
    parser.add_argument(
        '-i',
        '--input',
        help='read threats to lookup from FILE',
    )
    parser.add_argument(
        '-td',
        '--threat_details',
        action='store_true',
        help='set if you also want to have access to the threat details ',
    )
    parser.add_argument(
        '-ot',
        '--output_type',
        default='json',
        help=
        'set to the output type desired {json,csv}. Default is json if not specified',
    )
    required_named.add_argument(
        '-a',
        '--atom_type',
        help='set it to define the atom type',
        required=True,
    )
    csv_control.add_argument(
        '--is_csv',
        help='set if the file input is a CSV',
        action='store_true',
    )
    csv_control.add_argument(
        '-d',
        '--delimiter',
        help='set the delimiter of the CSV file',
        default=',',
    )
    csv_control.add_argument(
        '-c',
        '--column',
        help='select column of the CSV file, starting at 1',
        type=int,
        default=1,
    )
    if override_args:
        args = parser.parse_args(override_args)
    else:
        args = parser.parse_args()
    logger.debug('START: lookup_threats.py')

    if not args.threats and not args.input:
        parser.error("either a threat or an input_file is required")

    if args.atom_type not in PostEngine.authorized_atom_value:
        parser.error("atom type must be in {}".format(','.join(
            PostEngine.authorized_atom_value)))

    args.output_type = output_type2header(args.output_type, parser)
    hashkey_only = not args.threat_details
    # Load api_endpoints and tokens
    endpoint_config, _main_url, tokens = starter.load_config(args)
    get_engine_lookup_threats = LookupThreats(endpoint_config, args.env,
                                              tokens)
    list_threats = list(args.threats) if args.threats else []
    if args.input:
        if args.is_csv:
            try:
                # --column is 1-based, _load_csv expects a 0-based index
                list_threats = list_threats + starter._load_csv(
                    args.input, args.delimiter, args.column - 1)
            except ValueError as ve:
                logger.error(ve)
                # A bare exit() reported success (status 0) on an unreadable CSV.
                exit(1)
        else:
            list_threats = list_threats + starter._load_list(args.input)
    # removing duplicates while preserving order
    list_threats = list(OrderedDict.fromkeys(list_threats))
    response_dict = get_engine_lookup_threats.lookup_threats(
        list_threats, args.atom_type, hashkey_only, args.output_type)

    if args.output:
        starter.save_output(args.output, response_dict)
        logger.debug(f'Results saved in {args.output}\n')
    logger.debug('END: lookup_threats.py')