class OpenSearchDataStore(object):
    """Implements the datastore.

    Wraps an OpenSearch client and provides query building, searching,
    labelling and bulk import of events. Events handed to import_event() are
    queued in self.import_events and sent with the bulk API by
    flush_queued_events().
    """

    # Number of events to queue up when bulk inserting events.
    DEFAULT_FLUSH_INTERVAL = 1000
    DEFAULT_SIZE = 100
    DEFAULT_LIMIT = DEFAULT_SIZE  # Max events to return
    DEFAULT_FROM = 0
    DEFAULT_STREAM_LIMIT = 5000  # Max events to return when streaming results

    DEFAULT_FLUSH_RETRY_LIMIT = 3  # Max retries for flushing the queue.
    DEFAULT_EVENT_IMPORT_TIMEOUT = "3m"  # Timeout value for importing events.

    def __init__(self, host="127.0.0.1", port=9200):
        """Create a OpenSearch client.

        Connection settings (credentials, SSL, timeouts) are read from the
        Flask application config.

        Args:
            host: Hostname or IP address of the OpenSearch server.
            port: Port of the OpenSearch server.
        """
        super().__init__()
        # Per-index bookkeeping of bulk upload errors, filled in by
        # flush_queued_events().
        self._error_container = {}

        self.user = current_app.config.get("OPENSEARCH_USER", "user")
        self.password = current_app.config.get("OPENSEARCH_PASSWORD", "pass")
        self.ssl = current_app.config.get("OPENSEARCH_SSL", False)
        self.verify = current_app.config.get("OPENSEARCH_VERIFY_CERTS", True)
        self.timeout = current_app.config.get("OPENSEARCH_TIMEOUT", 10)

        parameters = {}
        if self.ssl:
            parameters["use_ssl"] = self.ssl
            parameters["verify_certs"] = self.verify

        if self.user and self.password:
            parameters["http_auth"] = (self.user, self.password)
        if self.timeout:
            parameters["timeout"] = self.timeout

        self.client = OpenSearch([{"host": host, "port": port}], **parameters)

        self.import_counter = Counter()
        # Flat list of alternating bulk headers and documents.
        self.import_events = []
        self._request_timeout = current_app.config.get(
            "TIMEOUT_FOR_EVENT_IMPORT", self.DEFAULT_EVENT_IMPORT_TIMEOUT
        )

    @staticmethod
    def _build_labels_query(sketch_id, labels):
        """Build OpenSearch query for Timesketch labels.

        Args:
            sketch_id: Integer of sketch primary key.
            labels: List of label names.

        Returns:
            OpenSearch query as a dictionary.
        """
        label_query = {"bool": {"must": []}}

        for label in labels:
            # Increase metrics counter per label
            METRICS["search_filter_label"].labels(label=label).inc()
            nested_query = {
                "nested": {
                    "query": {
                        "bool": {
                            "must": [
                                {"term": {"timesketch_label.name.keyword": label}},
                                {"term": {"timesketch_label.sketch_id": sketch_id}},
                            ]
                        }
                    },
                    "path": "timesketch_label",
                }
            }
            label_query["bool"]["must"].append(nested_query)
        return label_query

    @staticmethod
    def _build_events_query(events):
        """Build OpenSearch query for one or more document ids.

        Args:
            events: List of dicts, each carrying the OpenSearch document ID
                under the "event_id" key.

        Returns:
            OpenSearch query as a dictionary.
        """
        events_list = [event["event_id"] for event in events]
        query_dict = {"query": {"ids": {"values": events_list}}}
        return query_dict

    @staticmethod
    def _build_query_dsl(query_dsl, timeline_ids):
        """Build OpenSearch Search DSL query by adding in timeline filtering.

        The passed in query_dsl dict is modified in place and returned.

        Args:
            query_dsl: A dict with the current query_dsl
            timeline_ids: Either a list of timeline IDs (int) or None.

        Returns:
            OpenSearch query DSL as a dictionary.
        """
        # Remove any aggregation coming from user supplied Query DSL.
        # We have no way to display this data in a good way today.
        if query_dsl.get("aggregations", None):
            del query_dsl["aggregations"]

        if not timeline_ids:
            return query_dsl

        if not isinstance(timeline_ids, (list, tuple)):
            es_logger.error(
                "Attempting to pass in timelines to a query DSL, but the "
                "passed timelines are not a list."
            )
            return query_dsl

        if not all(isinstance(x, int) for x in timeline_ids):
            es_logger.error("All timeline IDs need to be an integer.")
            return query_dsl

        old_query = query_dsl.get("query")
        if not old_query:
            return query_dsl

        # (old_query AND timeline_id NOT EXISTS) OR
        # (old_query AND timeline_id in LIST)
        query_dsl["query"] = {
            "bool": {
                "must": [],
                "should": [
                    {
                        "bool": {
                            "must": old_query,
                            "must_not": [
                                {"exists": {"field": "__ts_timeline_id"}},
                            ],
                        }
                    },
                    {
                        "bool": {
                            "must": [
                                {"terms": {"__ts_timeline_id": timeline_ids}},
                                old_query,
                            ],
                            "must_not": [],
                            "filter": [{"exists": {"field": "__ts_timeline_id"}}],
                        }
                    },
                ],
                "must_not": [],
                "filter": [],
            }
        }
        return query_dsl

    @staticmethod
    def _convert_to_time_range(interval):
        """Convert an interval timestamp into start and end dates.

        The last two space separated items of the interval hold the amount to
        subtract and the amount to add; the unit (s, m, h or d) is taken from
        the letters of the last item. Everything before that is the timestamp.

        Args:
            interval: Time frame representation

        Returns:
            Start timestamp in string format.
            End timestamp in string format.

        Raises:
            RuntimeError: If the unit is not one of s, m, h or d.
        """
        # return ('2018-12-05T00:00:00', '2018-12-05T23:59:59')
        TS_FORMAT = "%Y-%m-%dT%H:%M:%S"

        def get_digits(text):
            return int("".join(filter(str.isdigit, text)))

        def get_alpha(text):
            return "".join(filter(str.isalpha, text))

        ts_parts = interval.split(" ")
        # The start date could be 1 or 2 first items
        start = " ".join(ts_parts[0 : len(ts_parts) - 2])
        minus = get_digits(ts_parts[-2])
        plus = get_digits(ts_parts[-1])
        unit = get_alpha(ts_parts[-1])

        start_ts = parser.parse(start)

        rd = relativedelta.relativedelta
        if unit == "s":
            start_range = start_ts - rd(seconds=minus)
            end_range = start_ts + rd(seconds=plus)
        elif unit == "m":
            start_range = start_ts - rd(minutes=minus)
            end_range = start_ts + rd(minutes=plus)
        elif unit == "h":
            start_range = start_ts - rd(hours=minus)
            end_range = start_ts + rd(hours=plus)
        elif unit == "d":
            start_range = start_ts - rd(days=minus)
            end_range = start_ts + rd(days=plus)
        else:
            raise RuntimeError("Unable to parse the timestamp: " + str(unit))

        return start_range.strftime(TS_FORMAT), end_range.strftime(TS_FORMAT)

    def build_query(
        self,
        sketch_id,
        query_string,
        query_filter,
        query_dsl=None,
        aggregations=None,
        timeline_ids=None,
    ):
        """Build OpenSearch DSL query.

        A supplied query_dsl takes precedence over everything else, followed
        by an explicit list of events in the filter. Otherwise the query is
        assembled from the query string and the filter chips.

        Args:
            sketch_id: Integer of sketch primary key
            query_string: Query string
            query_filter: Dictionary containing filters to apply
            query_dsl: Dictionary containing OpenSearch DSL query
            aggregations: Dict of OpenSearch aggregations
            timeline_ids: Optional list of IDs of Timeline objects that should
                be queried as part of the search.

        Returns:
            OpenSearch DSL query as a dictionary
        """
        if query_dsl:
            if not isinstance(query_dsl, dict):
                query_dsl = json.loads(query_dsl)

            if not query_dsl:
                query_dsl = {}

            return self._build_query_dsl(query_dsl, timeline_ids)

        if query_filter.get("events", None):
            events = query_filter["events"]
            return self._build_events_query(events)

        query_dsl = {"query": {"bool": {"must": [], "must_not": [], "filter": []}}}

        if query_string:
            query_dsl["query"]["bool"]["must"].append(
                {"query_string": {"query": query_string, "default_operator": "AND"}}
            )

        # New UI filters
        if query_filter.get("chips", None):
            labels = []
            must_filters = query_dsl["query"]["bool"]["must"]
            must_not_filters = query_dsl["query"]["bool"]["must_not"]
            datetime_ranges = {"bool": {"should": [], "minimum_should_match": 1}}

            for chip in query_filter["chips"]:
                # Exclude chips that the user disabled
                if not chip.get("active", True):
                    continue

                # Increase metrics per chip type
                METRICS["search_filter_type"].labels(type=chip["type"]).inc()
                if chip["type"] == "label":
                    labels.append(chip["value"])

                elif chip["type"] == "term":
                    term_filter = {
                        "match_phrase": {
                            "{}".format(chip["field"]): {
                                "query": "{}".format(chip["value"])
                            }
                        }
                    }

                    if chip["operator"] == "must":
                        must_filters.append(term_filter)

                    elif chip["operator"] == "must_not":
                        must_not_filters.append(term_filter)

                elif chip["type"].startswith("datetime"):
                    if chip["type"] == "datetime_range":
                        start, end = chip["value"].split(",")
                    elif chip["type"] == "datetime_interval":
                        start, end = self._convert_to_time_range(chip["value"])
                    else:
                        continue
                    datetime_ranges["bool"]["should"].append(
                        {"range": {"datetime": {"gte": start, "lte": end}}}
                    )

            label_filter = self._build_labels_query(sketch_id, labels)
            must_filters.append(label_filter)
            must_filters.append(datetime_ranges)

        # Pagination
        if query_filter.get("from", None):
            query_dsl["from"] = query_filter["from"]

        # Number of events to return
        if query_filter.get("size", None):
            query_dsl["size"] = query_filter["size"]

        # Make sure we are sorting.
        if not query_dsl.get("sort", None):
            query_dsl["sort"] = {"datetime": query_filter.get("order", "asc")}

        # Add any pre defined aggregations
        if aggregations:
            # post_filter happens after aggregation so we need to move the
            # filter to the query instead.
            if query_dsl.get("post_filter", None):
                query_dsl["query"]["bool"]["filter"] = query_dsl["post_filter"]
                query_dsl.pop("post_filter", None)
            query_dsl["aggregations"] = aggregations

        # TODO: Simplify this when we don't have to support both timelines
        # that have __ts_timeline_id set and those that don't.
        # (query_string AND timeline_id NOT EXISTS) OR (
        #     query_string AND timeline_id in LIST)
        if timeline_ids and isinstance(timeline_ids, (list, tuple)):
            must_filters_pre = copy.copy(query_dsl["query"]["bool"]["must"])
            must_not_filters_pre = copy.copy(query_dsl["query"]["bool"]["must_not"])

            must_filters_post = copy.copy(query_dsl["query"]["bool"]["must"])
            must_not_filters_post = copy.copy(query_dsl["query"]["bool"]["must_not"])

            must_not_filters_pre.append(
                {"exists": {"field": "__ts_timeline_id"}},
            )

            must_filters_post.append({"terms": {"__ts_timeline_id": timeline_ids}})

            query_dsl["query"] = {
                "bool": {
                    "must": [],
                    "should": [
                        {
                            "bool": {
                                "must": must_filters_pre,
                                "must_not": must_not_filters_pre,
                            }
                        },
                        {
                            "bool": {
                                "must": must_filters_post,
                                "must_not": must_not_filters_post,
                                "filter": [
                                    {"exists": {"field": "__ts_timeline_id"}}
                                ],
                            }
                        },
                    ],
                    "must_not": [],
                    "filter": [],
                }
            }

        return query_dsl

    # pylint: disable=too-many-arguments
    def search(
        self,
        sketch_id,
        query_string,
        query_filter,
        query_dsl,
        indices,
        count=False,
        aggregations=None,
        return_fields=None,
        enable_scroll=False,
        timeline_ids=None,
    ):
        """Search OpenSearch.

        This will take a query string from the UI together with a filter
        definition. Based on this it will execute the search request on
        OpenSearch and get result back.

        Args:
            sketch_id: Integer of sketch primary key
            query_string: Query string
            query_filter: Dictionary containing filters to apply
            query_dsl: Dictionary containing OpenSearch DSL query
            indices: List of indices to query
            count: Boolean indicating if we should only return result count
            aggregations: Dict of OpenSearch aggregations
            return_fields: List of fields to return
            enable_scroll: If OpenSearch scroll API should be used
            timeline_ids: Optional list of IDs of Timeline objects that should
                be queried as part of the search.

        Returns:
            Set of event documents in JSON format, or the number of matching
            documents (int) if count is set. Without indices an empty result
            stub is returned.

        Raises:
            ValueError: If OpenSearch rejects the query when return_fields
                is set.
        """
        scroll_timeout = None
        if enable_scroll:
            scroll_timeout = "1m"  # Default to 1 minute scroll timeout

        # Exit early if we have no indices to query
        if not indices:
            return {"hits": {"hits": [], "total": 0}, "took": 0}

        # Make sure that the list of index names is uniq.
        indices = list(set(indices))

        # Check if we have specific events to fetch and get indices.
        if query_filter.get("events", None):
            indices = {
                event["index"]
                for event in query_filter["events"]
                if event["index"] in indices
            }

        query_dsl = self.build_query(
            sketch_id=sketch_id,
            query_string=query_string,
            query_filter=query_filter,
            query_dsl=query_dsl,
            aggregations=aggregations,
            timeline_ids=timeline_ids,
        )

        # Default search type for OpenSearch is query_then_fetch.
        search_type = "query_then_fetch"

        # Only return how many documents matches the query.
        if count:
            if "sort" in query_dsl:
                del query_dsl["sort"]
            try:
                count_result = self.client.count(body=query_dsl, index=list(indices))
            except NotFoundError:
                es_logger.error(
                    "Unable to count due to an index not found: {0:s}".format(
                        ",".join(indices)
                    )
                )
                return 0
            METRICS["search_requests"].labels(type="count").inc()
            return count_result.get("count", 0)

        if not return_fields:
            # Suppress the lint error because opensearchpy adds parameters
            # to the function with a decorator and this makes pylint sad.
            # pylint: disable=unexpected-keyword-arg
            return self.client.search(
                body=query_dsl,
                index=list(indices),
                search_type=search_type,
                scroll=scroll_timeout,
            )

        # The argument " _source_include" changed to "_source_includes" in
        # ES version 7. This check add support for both version 6 and 7 clients.
        # pylint: disable=unexpected-keyword-arg
        try:
            if self.version.startswith("6"):
                _search_result = self.client.search(
                    body=query_dsl,
                    index=list(indices),
                    search_type=search_type,
                    _source_include=return_fields,
                    scroll=scroll_timeout,
                )
            else:
                _search_result = self.client.search(
                    body=query_dsl,
                    index=list(indices),
                    search_type=search_type,
                    _source_includes=return_fields,
                    scroll=scroll_timeout,
                )
        except RequestError as e:
            root_cause = e.info.get("error", {}).get("root_cause")
            if root_cause:
                error_items = []
                for cause in root_cause:
                    error_items.append(
                        "[{0:s}] {1:s}".format(
                            cause.get("type", ""), cause.get("reason", "")
                        )
                    )
                cause = ", ".join(error_items)
            else:
                cause = str(e)

            es_logger.error(
                "Unable to run search query: {0:s}".format(cause), exc_info=True
            )
            raise ValueError(cause) from e

        METRICS["search_requests"].labels(type="single").inc()
        return _search_result

    # pylint: disable=too-many-arguments
    def search_stream(
        self,
        sketch_id=None,
        query_string=None,
        query_filter=None,
        query_dsl=None,
        indices=None,
        return_fields=None,
        enable_scroll=True,
        timeline_ids=None,
    ):
        """Search OpenSearch and stream the results.

        This will take a query string from the UI together with a filter
        definition. Based on this it will execute the search request on
        OpenSearch and yield the results, following the scroll cursor until
        a page comes back empty.

        Args:
            sketch_id: Integer of sketch primary key
            query_string: Query string
            query_filter: Dictionary containing filters to apply. Missing
                "size" and "terminate_after" entries are filled in with
                DEFAULT_STREAM_LIMIT (the dict is modified in place).
            query_dsl: Dictionary containing OpenSearch DSL query
            indices: List of indices to query
            return_fields: List of fields to return
            enable_scroll: Boolean determining whether scrolling is enabled.
            timeline_ids: Optional list of IDs of Timeline objects that should
                be queried as part of the search.

        Returns:
            Generator of event documents in JSON format
        """
        # The signature defaults are None, make them usable below instead of
        # failing on set(None) / None.get().
        if query_filter is None:
            query_filter = {}

        # Make sure that the list of index names is uniq.
        indices = list(set(indices or []))

        METRICS["search_requests"].labels(type="stream").inc()

        if not query_filter.get("size"):
            query_filter["size"] = self.DEFAULT_STREAM_LIMIT

        if not query_filter.get("terminate_after"):
            query_filter["terminate_after"] = self.DEFAULT_STREAM_LIMIT

        result = self.search(
            sketch_id=sketch_id,
            query_string=query_string,
            query_dsl=query_dsl,
            query_filter=query_filter,
            indices=indices,
            return_fields=return_fields,
            enable_scroll=enable_scroll,
            timeline_ids=timeline_ids,
        )

        if enable_scroll:
            # search() returns a stub without "_scroll_id" when there are no
            # indices to query, so do not index into the result directly.
            scroll_id = result.get("_scroll_id")
            scroll_size = result["hits"]["total"]
        else:
            scroll_id = None
            scroll_size = 0

        # Elasticsearch version 7.x returns total hits as a dictionary.
        # TODO: Refactor when version 6.x has been deprecated.
        if isinstance(scroll_size, dict):
            scroll_size = scroll_size.get("value", 0)

        for event in result["hits"]["hits"]:
            yield event

        while scroll_size > 0:
            # pylint: disable=unexpected-keyword-arg
            result = self.client.scroll(scroll_id=scroll_id, scroll="5m")
            scroll_id = result["_scroll_id"]
            scroll_size = len(result["hits"]["hits"])
            for event in result["hits"]["hits"]:
                yield event

    def get_filter_labels(self, sketch_id, indices):
        """Aggregate labels for a sketch.

        Args:
            sketch_id: The Sketch ID
            indices: List of indices to aggregate on

        Returns:
            List of dicts, one per label, with the keys "label" (the label
            name) and "count" (number of documents carrying it). Empty if the
            indices cannot be found.
        """
        # This is a workaround to return all labels by setting the max buckets
        # to something big. If a sketch has more than this amount of labels
        # the list will be incomplete but it should be uncommon to have >10k
        # labels in a sketch.
        max_labels = 10000

        # pylint: disable=line-too-long
        aggregation = {
            "aggs": {
                "nested": {
                    "nested": {"path": "timesketch_label"},
                    "aggs": {
                        "inner": {
                            "filter": {
                                "bool": {
                                    "must": [
                                        {
                                            "term": {
                                                "timesketch_label.sketch_id": sketch_id
                                            }
                                        }
                                    ]
                                }
                            },
                            "aggs": {
                                "labels": {
                                    "terms": {
                                        "size": max_labels,
                                        "field": "timesketch_label.name.keyword",
                                    }
                                }
                            },
                        }
                    },
                }
            }
        }

        # Make sure that the list of index names is uniq.
        indices = list(set(indices))

        labels = []
        # pylint: disable=unexpected-keyword-arg
        try:
            result = self.client.search(index=indices, body=aggregation, size=0)
        except NotFoundError:
            es_logger.error(
                "Unable to find the index/indices: {0:s}".format(",".join(indices))
            )
            return labels

        buckets = (
            result.get("aggregations", {})
            .get("nested", {})
            .get("inner", {})
            .get("labels", {})
            .get("buckets", [])
        )

        for bucket in buckets:
            new_bucket = {}
            new_bucket["label"] = bucket.pop("key")
            new_bucket["count"] = bucket.pop("doc_count")
            labels.append(new_bucket)
        return labels

    # pylint: disable=inconsistent-return-statements
    def get_event(self, searchindex_id, event_id):
        """Get one event from the datastore.

        Aborts the request with HTTP_STATUS_CODE_NOT_FOUND if the event does
        not exist.

        Args:
            searchindex_id: String of OpenSearch index id
            event_id: String of OpenSearch event id

        Returns:
            Event document in JSON format
        """
        METRICS["search_get_event"].inc()
        try:
            # Suppress the lint error because opensearchpy adds parameters
            # to the function with a decorator and this makes pylint sad.
            # pylint: disable=unexpected-keyword-arg
            if self.version.startswith("6"):
                event = self.client.get(
                    index=searchindex_id,
                    id=event_id,
                    doc_type="_all",
                    _source_exclude=["timesketch_label"],
                )
            else:
                event = self.client.get(
                    index=searchindex_id,
                    id=event_id,
                    doc_type="_all",
                    _source_excludes=["timesketch_label"],
                )

            return event

        except NotFoundError:
            abort(HTTP_STATUS_CODE_NOT_FOUND)

    def count(self, indices):
        """Count number of documents.

        Args:
            indices: List of indices.

        Returns:
            Tuple containing number of documents and size on disk. (0, 0) if
            there are no indices or the stats request fails.
        """
        if not indices:
            return 0, 0

        # Make sure that the list of index names is uniq.
        indices = list(set(indices))

        try:
            es_stats = self.client.indices.stats(index=indices, metric="docs, store")

        except NotFoundError:
            es_logger.error("Unable to count indices (index not found)")
            return 0, 0

        except RequestError:
            es_logger.error("Unable to count indices (request error)", exc_info=True)
            return 0, 0

        doc_count_total = (
            es_stats.get("_all", {})
            .get("primaries", {})
            .get("docs", {})
            .get("count", 0)
        )
        doc_bytes_total = (
            es_stats.get("_all", {})
            .get("primaries", {})
            .get("store", {})
            .get("size_in_bytes", 0)
        )

        return doc_count_total, doc_bytes_total

    def set_label(
        self,
        searchindex_id,
        event_id,
        event_type,
        sketch_id,
        user_id,
        label,
        toggle=False,
        remove=False,
        single_update=True,
    ):
        """Set label on event in the datastore.

        Args:
            searchindex_id: String of OpenSearch index id
            event_id: String of OpenSearch event id
            event_type: String of OpenSearch document type
            sketch_id: Integer of sketch primary key
            user_id: Integer of user primary key
            label: String with the name of the label
            remove: Optional boolean value if the label should be removed
            toggle: Optional boolean value if the label should be toggled
            single_update: Boolean if the label should be indexed immediately.

        Returns:
            None if this is a single update (the document is updated right
            away), otherwise a dict with the script source, language and
            parameters to be applied later.
        """
        # OpenSearch painless script.
        update_body = {
            "script": {
                "lang": "painless",
                "source": UPDATE_LABEL_SCRIPT,
                "params": {
                    "timesketch_label": {
                        "name": str(label),
                        "user_id": user_id,
                        "sketch_id": sketch_id,
                    },
                    # The key must be the string "remove"; using the boolean
                    # value as the key would hide the flag from the script.
                    "remove": remove,
                },
            }
        }

        if toggle:
            update_body["script"]["source"] = TOGGLE_LABEL_SCRIPT

        if not single_update:
            script = update_body["script"]
            return dict(
                source=script["source"], lang=script["lang"], params=script["params"]
            )

        doc = self.client.get(index=searchindex_id, id=event_id, doc_type="_all")
        try:
            doc["_source"]["timesketch_label"]
        except KeyError:
            # The document has no label field yet, create an empty one before
            # running the script against it.
            doc = {"doc": {"timesketch_label": []}}
            self.client.update(
                index=searchindex_id, doc_type=event_type, id=event_id, body=doc
            )

        self.client.update(
            index=searchindex_id, id=event_id, doc_type=event_type, body=update_body
        )

        return None

    def create_index(self, index_name=None, doc_type="generic_event", mappings=None):
        """Create index with Timesketch settings.

        Args:
            index_name: Name of the index. Default is a generated UUID, a new
                one for every call.
            doc_type: Name of the document type. Default id generic_event.
            mappings: Optional dict with the document mapping for OpenSearch.

        Returns:
            Index name in string format.
            Document type in string format.

        Raises:
            RuntimeError: If the connection to the backend fails.
        """
        # Generate the name here and not in the signature: a default value is
        # evaluated only once, which would hand out the same UUID every time.
        if index_name is None:
            index_name = uuid4().hex

        if mappings:
            _document_mapping = mappings
        else:
            _document_mapping = {
                "properties": {
                    "timesketch_label": {"type": "nested"},
                    "datetime": {"type": "date"},
                }
            }

        # TODO: Remove when we deprecate OpenSearch version 6.x
        if self.version.startswith("6"):
            _document_mapping = {doc_type: _document_mapping}

        if not self.client.indices.exists(index_name):
            try:
                self.client.indices.create(
                    index=index_name, body={"mappings": _document_mapping}
                )
            except ConnectionError as e:
                raise RuntimeError("Unable to connect to Timesketch backend.") from e
            except RequestError:
                index_exists = self.client.indices.exists(index_name)
                es_logger.warning(
                    "Attempting to create an index that already exists "
                    "({0:s} - {1:s})".format(index_name, str(index_exists))
                )

        return index_name, doc_type

    def delete_index(self, index_name):
        """Delete OpenSearch index.

        Args:
            index_name: Name of the index to delete.

        Raises:
            RuntimeError: If the connection to the backend fails.
        """
        if self.client.indices.exists(index_name):
            try:
                self.client.indices.delete(index=index_name)
            except ConnectionError as e:
                raise RuntimeError(
                    "Unable to connect to Timesketch backend: {}".format(e)
                ) from e

    def import_event(
        self,
        index_name,
        event_type,
        event=None,
        event_id=None,
        flush_interval=DEFAULT_FLUSH_INTERVAL,
        timeline_id=None,
    ):
        """Add event to OpenSearch.

        The event is queued and the queue is flushed every flush_interval
        events. Calling this without an event flushes whatever is still
        queued.

        Args:
            index_name: Name of the index in OpenSearch
            event_type: Type of event (e.g. plaso_event)
            event: Event dictionary
            event_id: Event OpenSearch ID
            flush_interval: Number of events to queue up before indexing
            timeline_id: Optional ID number of a Timeline object this event
                belongs to. If supplied an additional field will be added to
                the store indicating the timeline this belongs to.

        Returns:
            The number of events imported so far.
        """
        if event:
            # Iterate over a snapshot: keys may be replaced below and a dict
            # must not change size while it is being iterated.
            for k, v in list(event.items()):
                if isinstance(k, bytes):
                    del event[k]
                    k = codecs.decode(k, "utf8")

                # Make sure we have decoded strings in the event dict.
                if isinstance(v, bytes):
                    v = codecs.decode(v, "utf8")

                event[k] = v

            # Header needed by OpenSearch when bulk inserting.
            header = {
                "index": {
                    "_index": index_name,
                }
            }
            update_header = {"update": {"_index": index_name, "_id": event_id}}

            # TODO: Remove when we deprecate Elasticsearch version 6.x
            if self.version.startswith("6"):
                header["index"]["_type"] = event_type
                update_header["update"]["_type"] = event_type

            if event_id:
                # Event has "lang" defined if there is a script used for import.
                if event.get("lang"):
                    event = {"script": event}
                else:
                    event = {"doc": event}
                header = update_header

            if timeline_id:
                event["__ts_timeline_id"] = timeline_id

            self.import_events.append(header)
            self.import_events.append(event)
            self.import_counter["events"] += 1

            if self.import_counter["events"] % int(flush_interval) == 0:
                _ = self.flush_queued_events()
                self.import_events = []
        else:
            # Import the remaining events in the queue.
            if self.import_events:
                _ = self.flush_queued_events()

        return self.import_counter["events"]

    def flush_queued_events(self, retry_count=0):
        """Flush all queued events.

        Args:
            retry_count: optional int indicating whether this is a retry.

        Returns:
            dict: A dict object that contains the number of events
                that were sent to OpenSearch as well as information
                on whether there were any errors, and what the
                details of these errors if any. Empty if there was nothing
                to send or the retry limit was reached.
        """
        if not self.import_events:
            return {}

        return_dict = {
            # The queue holds a header and a document per event; integer
            # division keeps the count an int.
            "number_of_events": len(self.import_events) // 2,
            "total_events": self.import_counter["events"],
        }

        try:
            # pylint: disable=unexpected-keyword-arg
            results = self.client.bulk(
                body=self.import_events, timeout=self._request_timeout
            )
        except (ConnectionTimeout, socket.timeout):
            if retry_count >= self.DEFAULT_FLUSH_RETRY_LIMIT:
                es_logger.error(
                    "Unable to add events, reached recount max.", exc_info=True
                )
                return {}

            es_logger.error(
                "Unable to add events (retry {0:d}/{1:d})".format(
                    retry_count, self.DEFAULT_FLUSH_RETRY_LIMIT
                )
            )
            return self.flush_queued_events(retry_count + 1)

        errors_in_upload = results.get("errors", False)
        return_dict["errors_in_upload"] = errors_in_upload

        if errors_in_upload:
            items = results.get("items", [])
            return_dict["errors"] = []

            es_logger.error("Errors while attempting to upload events.")
            for item in items:
                index = item.get("index", {})
                index_name = index.get("_index", "N/A")

                _ = self._error_container.setdefault(
                    index_name,
                    {"errors": [], "types": Counter(), "details": Counter()},
                )

                error_counter = self._error_container[index_name]["types"]
                error_detail_counter = self._error_container[index_name]["details"]
                error_list = self._error_container[index_name]["errors"]

                error = index.get("error", {})
                status_code = index.get("status", 0)
                doc_id = index.get("_id", "(unable to get doc id)")
                caused_by = error.get("caused_by", {})

                caused_reason = caused_by.get("reason", "Unkown Detailed Reason")

                error_counter[error.get("type")] += 1
                # Only the first five words of the reason are kept so that
                # similar errors end up in the same bucket.
                detail_msg = "{0:s}/{1:s}".format(
                    caused_by.get("type", "Unknown Detailed Type"),
                    " ".join(caused_reason.split()[:5]),
                )
                error_detail_counter[detail_msg] += 1

                error_msg = "<{0:s}> {1:s} [{2:s}/{3:s}]".format(
                    error.get("type", "Unknown Type"),
                    error.get("reason", "No reason given"),
                    caused_by.get("type", "Unknown Type"),
                    caused_reason,
                )
                error_list.append(error_msg)
                try:
                    es_logger.error(
                        "Unable to upload document: {0:s} to index {1:s} - "
                        "[{2:d}] {3:s}".format(
                            doc_id, index_name, status_code, error_msg
                        )
                    )
                # We need to catch all exceptions here, since this is a crucial
                # call that we do not want to break operation.
                except Exception:  # pylint: disable=broad-except
                    es_logger.error(
                        "Unable to upload document, and unable to log the "
                        "error itself.",
                        exc_info=True,
                    )

        return_dict["error_container"] = self._error_container

        self.import_events = []
        return return_dict

    @property
    def version(self):
        """Get OpenSearch version.

        Returns:
            Version number as a string.
        """
        version_info = self.client.info().get("version")
        return version_info.get("number")
class SearchApiClient: def __init__(self, host=settings.OPENSEARCH_HOST): protocol = settings.OPENSEARCH_PROTOCOL protocol_config = {} if protocol == "https": protocol_config = { "scheme": "https", "port": 443, "use_ssl": True, "verify_certs": settings.OPENSEARCH_VERIFY_CERTS, } if settings.IS_AWS: http_auth = ("supersurf", settings.OPENSEARCH_PASSWORD) else: http_auth = (None, None) self.client = OpenSearch([host], http_auth=http_auth, connection_class=RequestsHttpConnection, **protocol_config) self.index_nl = settings.OPENSEARCH_NL_INDEX self.index_en = settings.OPENSEARCH_EN_INDEX self.index_unk = settings.OPENSEARCH_UNK_INDEX self.languages = {"nl": self.index_nl, "en": self.index_en} @staticmethod def parse_search_result(search_result): """ Parses the search result into the correct format that the frontend uses :param search_result: result from search :return result: list of results ready for frontend """ hits = search_result.pop("hits") aggregations = search_result.get("aggregations", {}) result = dict() result['recordcount'] = hits['total']['value'] # Transform aggregations into drilldowns drilldowns = {} for aggregation_name, aggregation in aggregations.items(): buckets = aggregation["filtered"][ "buckets"] if "filtered" in aggregation else aggregation[ "buckets"] for bucket in buckets: drilldowns[f"{aggregation_name}-{bucket['key']}"] = bucket[ "doc_count"] result['drilldowns'] = drilldowns # Parse spelling suggestions did_you_mean = {} if 'suggest' in search_result: spelling_suggestion = search_result['suggest'][ 'did-you-mean-suggestion'][0] spelling_option = spelling_suggestion['options'][0] if len( spelling_suggestion['options']) else None if spelling_option is not None and spelling_option["score"] >= 0.01: did_you_mean = { 'original': spelling_suggestion['text'], 'suggestion': spelling_option['text'] } result['did_you_mean'] = did_you_mean # Transform hits into records result['records'] = [ SearchApiClient.parse_search_hit(hit) for hit in hits['hits'] 
] return result @staticmethod def parse_search_hit(hit, transform=True): """ Parses the search hit into the format that is also used by the edurep endpoint. It's mostly just mapping the variables we need into the places that we expect them to be. :param hit: result from search :return record: parsed record """ data = hit["_source"] serializer = SearchResultSerializer() # Basic mapping between field and data (excluding any method fields with a source of "*") field_mapping = { field.source: field_name if transform else field.source for field_name, field in serializer.fields.items() if field.source != "*" } record = { field_mapping[field]: value for field, value in data.items() if field in field_mapping } # Reformatting some fields if a relations field is desired if "relations" in field_mapping: publishers = [{ "name": publisher } for publisher in data.get("publishers", [])] record["relations"] = { "authors": data.get("authors", []), "parties": data.get("parties", []) or publishers, "projects": data.get("projects", []), "keywords": [{ "label": keyword } for keyword in data.get("keywords", [])], "themes": [{ "label": theme } for theme in data.get("research_themes", [])], "parents": data.get("is_part_of", []), "children": data.get("has_parts", []) } # Calling methods on serializers to set data for method fields for field_name, field in serializer.fields.items(): if field.source != "*": continue record[field_name] = getattr(serializer, field.method_name)(data) # Add highlight to the record if hit.get("highlight", 0): record["highlight"] = hit["highlight"] return record def autocomplete(self, query): """ Use the suggest query to get typing hints during searching. 
:param query: the input from the user so far :return: a list of options matching the input query, sorted by length """ # build the query for search engine query_dictionary = { 'suggest': { "autocomplete": { 'text': query, "completion": { "field": "suggest_completion", "size": 100 } } } } result = self.client.search( index=[self.index_nl, self.index_en, self.index_unk], body=query_dictionary) # extract the options from the search result, remove duplicates, # remove non-matching prefixes (engine will suggest things that don't match _exactly_) # and sort by length autocomplete = result['suggest']['autocomplete'] options = autocomplete[0]['options'] flat_options = list( set([ item for option in options for item in option['_source']['suggest_completion'] ])) options_with_prefix = [ option for option in flat_options if option.lower().startswith(query.lower()) ] options_with_prefix.sort(key=lambda option: len(option)) return options_with_prefix def drilldowns(self, drilldown_names, search_text=None, filters=None): """ This function is named drilldowns is because it's also named drilldowns in the original edurep search code. It passes on information to search, and returns the search without the records. This allows calculation of 'item counts' (i.e. how many results there are in through a certain filter) """ search_results = self.search(search_text=search_text, filters=filters, drilldown_names=drilldown_names) search_results["records"] = [] return search_results def search(self, search_text, drilldown_names=None, filters=None, ordering=None, page=1, page_size=5): """ Build and send a query to search engine and parse it before returning. :param search_text: A list of strings to search for. :param drilldown_names: A list of the 'drilldowns' (filters) that are to be counted by engine. :param filters: The filters that are applied for this search. 
:param ordering: Sort the results by this ordering (or use default search ordering otherwise) :param page: The page index of the results :param page_size: How many items are loaded per page. :return: """ start_record = page_size * (page - 1) body = { 'query': { "bool": defaultdict(list) }, 'from': start_record, 'size': page_size, 'post_filter': { "bool": defaultdict(list) }, 'highlight': { 'number_of_fragments': 1, 'fragment_size': 120, 'fields': { 'description': {}, 'text': {} } } } if search_text: query_string = { "simple_query_string": { "fields": SEARCH_FIELDS, "query": search_text, "default_operator": "and" } } body["query"]["bool"]["must"] += [query_string] body["query"]["bool"]["should"] = { "distance_feature": { "field": "publisher_date", "pivot": "90d", "origin": "now", "boost": 1.15 } } body["suggest"] = { 'did-you-mean-suggestion': { 'text': search_text, 'phrase': { 'field': 'suggest_phrase', 'size': 1, 'gram_size': 3, 'direct_generator': [{ 'field': 'suggest_phrase', 'suggest_mode': 'always' }], }, } } indices = self.parse_index_language(self, filters) if drilldown_names: body["aggs"] = self.parse_aggregations(drilldown_names, filters) filters = self.parse_filters(filters) if filters: body["post_filter"]["bool"]["must"] += filters if ordering: body["sort"] = [self.parse_ordering(ordering), "_score"] # make query and parse result = self.client.search(index=indices, body=body) return self.parse_search_result(result) def get_materials_by_id(self, external_ids, page=1, page_size=10, **kwargs): """ Retrieve specific materials from search engine through their external id. :param external_ids: the id's of the materials to retrieve :param page: The page index of the results :param page_size: How many items are loaded per page. :return: a list of search results (like a regular search). 
""" start_record = page_size * (page - 1) normalized_external_ids = [] for external_id in external_ids: if not external_id.startswith("surf"): normalized_external_ids.append(external_id) else: external_id_parts = external_id.split(":") normalized_external_ids.append(external_id_parts[-1]) result = self.client.search( index=[self.index_nl, self.index_en, self.index_unk], body={ "query": { "bool": { "must": [{ "terms": { "external_id": normalized_external_ids } }] } }, 'from': start_record, 'size': page_size, }, ) results = self.parse_search_result(result) materials = { material["external_id"]: material for material in results["records"] } records = [] for external_id in normalized_external_ids: if external_id not in materials: continue records.append(materials[external_id]) results["recordcount"] = len(records) results["records"] = records return results def stats(self): stats = self.client.count( index=",".join([self.index_nl, self.index_en, self.index_unk])) return stats.get("count", 0) def more_like_this(self, external_id, language): index = self.languages.get(language, self.index_unk) body = { "query": { "more_like_this": { "fields": ["title", "description"], "like": [{ "_index": index, "_id": external_id }], "min_term_freq": 1, "max_query_terms": 12 } } } search_result = self.client.search(index=index, body=body) hits = search_result.pop("hits") result = dict() result["records_total"] = hits["total"]["value"] result["results"] = [ SearchApiClient.parse_search_hit(hit, transform=False) for hit in hits["hits"] ] return result def author_suggestions(self, author_name): body = { "query": { "bool": { "must": { "multi_match": { "fields": [ field for field in SEARCH_FIELDS if "authors" not in field ], "query": author_name, }, }, "must_not": { "match": { "authors.name.folded": author_name } } } } } search_result = self.client.search( index=[self.index_nl, self.index_en, self.index_unk], body=body) hits = search_result.pop("hits") result = dict() result["records_total"] 
= hits["total"]["value"] result["results"] = [ SearchApiClient.parse_search_hit(hit, transform=False) for hit in hits["hits"] ] return result @staticmethod def parse_filters(filters): """ Parse filters from the frontend format into the search engine format. Not every filter is handled by search engine in the same way so it's a lot of manual parsing. :param filters: the list of filters to be parsed :return: the filters in the format for a search query. """ if not filters: return {} filter_items = [] for filter_item in filters: # skip filter_items that are empty # and the language filter item (it's handled by telling search engine in what index to search). if not filter_item['items'] or 'language.keyword' in filter_item[ 'external_id']: continue search_type = filter_item['external_id'] # date range query if search_type == "publisher_date": lower_bound, upper_bound = filter_item["items"] if lower_bound is not None or upper_bound is not None: filter_items.append({ "range": { "publisher_date": { "gte": lower_bound, "lte": upper_bound } } }) # all other filter types are handled by just using terms with the 'translated' filter items else: filter_items.append( {"terms": { search_type: filter_item["items"] }}) return filter_items def parse_aggregations(self, aggregation_names, filters): """ Parse the aggregations so search engine can count the items properly. 
:param aggregation_names: the names of the aggregations to :param filters: the filters for the query :return: """ aggregation_items = {} for aggregation_name in aggregation_names: other_filters = [] if filters: other_filters = list( filter(lambda x: x['external_id'] != aggregation_name, filters)) other_filters = self.parse_filters(other_filters) search_type = aggregation_name if len(other_filters) > 0: # Filter the aggregation by the filters applied to other categories aggregation_items[aggregation_name] = { "filter": { "bool": { "must": other_filters } }, "aggs": { "filtered": { "terms": { "field": search_type, "size": 2000, } } }, } else: aggregation_items[aggregation_name] = { "terms": { "field": search_type, "size": 2000, } } return aggregation_items @staticmethod def parse_ordering(ordering): """ Parse the frontend ordering format ('asc', 'desc' or None) into the type that search engine expects. """ order = "asc" if ordering.startswith("-"): order = "desc" ordering = ordering[1:] search_type = ordering return {search_type: {"order": order}} @staticmethod def parse_index_language(self, filters): """ Select the index to search on based on language. """ # if no language is selected, search on both. indices = [self.index_nl, self.index_en, self.index_unk] if not filters: return indices language_item = [ filter_item for filter_item in filters if filter_item['external_id'] == 'language.keyword' ] if not language_item: return indices language_indices = [ f"latest-{language}" for language in language_item[0]['items'] ] return language_indices if len(language_indices) else indices