Example #1
import pprint

from opensearchpy import OpenSearch, helpers


class ElasticsearchSampler:
    """Elasticsearch sample class."""

    def __init__(self):
        host = 'localhost'
        port = 9200
        auth = ('admin', 'admin')
        # certs = 'esnode.pem'

        # Create the Elasticsearch (OpenSearch) client instance
        self.es = OpenSearch(
            hosts=[{'host': host, 'port': port}],
            http_auth=auth,
            use_ssl=True,
            verify_certs=False,
            # ca_certs=certs,
            ssl_assert_hostname=False,
            ssl_show_warn=False,
        )

    def __del__(self):
        self.es.close()
        print("close elasticsearch instance--------------------------")

    def search(self, idx: str, query: str):
        """検索
        """
        result = self.es.search(index=idx, body=query)
        print('--[search]-------------------------------------------')
        pprint.pprint(result, sort_dicts=False)

    def bulk(self, index: str):
        """バルクインサート
        """

        try:
            # Any iterable object works here, so either of the following is possible:
            # - pass a generator
            success, failed = helpers.bulk(self.es, gendata3(index))
            # - pass a list
            # success, failed = helpers.bulk(self.es, bulklist())
        # except opensearchpy.ElasticsearchException as e:
        #     pprint.pprint(e)
        except Exception as e:
            pprint.pprint(e)
            return

        print('--[bulk  ]-------------------------------------------')
        pprint.pprint(success)
        pprint.pprint(failed)

    def delete_by_query(self, idx: str, query: str):
        """条件指定の削除
        """
        result = self.es.delete_by_query(index=idx, body=query)

        print(f'{type(result)}')
        print('--[delete_by_query]----------------------------------')
        pprint.pprint(result, sort_dicts=False)
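
A minimal usage sketch for the class above. The index name, the match_all query, and the behaviour of the gendata3() helper are illustrative assumptions; they are not part of the original snippet.

if __name__ == '__main__':
    # Hypothetical driver code; assumes a local OpenSearch node and that
    # gendata3('sample-index') yields bulk actions targeting that index.
    sampler = ElasticsearchSampler()
    match_all = '{"query": {"match_all": {}}}'
    sampler.bulk('sample-index')
    sampler.search('sample-index', match_all)
    sampler.delete_by_query('sample-index', match_all)
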
Example #2
from opensearchpy import OpenSearch


def test_es_search():
    """Search after running management command to fill ES from data sources."""

    es = OpenSearch([{"host": "localhost", "port": 9200}])

    query = {
        "query": {
            "match": {
                "fi": {
                    "query": "kivist",
                    "fuzziness": "AUTO"
                }
            }
        }
    }

    s = es.search(index="test-index", body=query)

    hits = s["hits"]["total"]["value"]
    assert hits == 1
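
The test above assumes the index already contains a document whose "fi" field fuzzily matches "kivist". A sketch of how such a fixture could be prepared; the index name and the document body are assumptions:

def _prepare_test_index():
    # Hypothetical setup helper: index one Finnish document so the fuzzy
    # match in test_es_search() returns exactly one hit.
    es = OpenSearch([{"host": "localhost", "port": 9200}])
    es.indices.create(index="test-index", ignore=400)  # ignore "index already exists"
    es.index(index="test-index", body={"fi": "kivistö"}, refresh=True)
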
Example #3
class OpenSearchDataStore(object):
    """Implements the datastore."""

    # Number of events to queue up when bulk inserting events.
    DEFAULT_FLUSH_INTERVAL = 1000
    DEFAULT_SIZE = 100
    DEFAULT_LIMIT = DEFAULT_SIZE  # Max events to return
    DEFAULT_FROM = 0
    DEFAULT_STREAM_LIMIT = 5000  # Max events to return when streaming results

    DEFAULT_FLUSH_RETRY_LIMIT = 3  # Max retries for flushing the queue.
    DEFAULT_EVENT_IMPORT_TIMEOUT = "3m"  # Timeout value for importing events.

    def __init__(self, host="127.0.0.1", port=9200):
        """Create a OpenSearch client."""
        super().__init__()
        self._error_container = {}

        self.user = current_app.config.get("OPENSEARCH_USER", "user")
        self.password = current_app.config.get("OPENSEARCH_PASSWORD", "pass")
        self.ssl = current_app.config.get("OPENSEARCH_SSL", False)
        self.verify = current_app.config.get("OPENSEARCH_VERIFY_CERTS", True)
        self.timeout = current_app.config.get("OPENSEARCH_TIMEOUT", 10)

        parameters = {}
        if self.ssl:
            parameters["use_ssl"] = self.ssl
            parameters["verify_certs"] = self.verify

        if self.user and self.password:
            parameters["http_auth"] = (self.user, self.password)
        if self.timeout:
            parameters["timeout"] = self.timeout

        self.client = OpenSearch([{"host": host, "port": port}], **parameters)

        self.import_counter = Counter()
        self.import_events = []
        self._request_timeout = current_app.config.get(
            "TIMEOUT_FOR_EVENT_IMPORT", self.DEFAULT_EVENT_IMPORT_TIMEOUT)

    @staticmethod
    def _build_labels_query(sketch_id, labels):
        """Build OpenSearch query for Timesketch labels.

        Args:
            sketch_id: Integer of sketch primary key.
            labels: List of label names.

        Returns:
            OpenSearch query as a dictionary.
        """
        label_query = {"bool": {"must": []}}

        for label in labels:
            # Increase metrics counter per label
            METRICS["search_filter_label"].labels(label=label).inc()
            nested_query = {
                "nested": {
                    "query": {
                        "bool": {
                            "must": [
                                {
                                    "term": {
                                        "timesketch_label.name.keyword": label
                                    }
                                },
                                {
                                    "term": {
                                        "timesketch_label.sketch_id": sketch_id
                                    }
                                },
                            ]
                        }
                    },
                    "path": "timesketch_label",
                }
            }
            label_query["bool"]["must"].append(nested_query)
        return label_query

    @staticmethod
    def _build_events_query(events):
        """Build OpenSearch query for one or more document ids.

        Args:
            events: List of OpenSearch document IDs.

        Returns:
            OpenSearch query as a dictionary.
        """
        events_list = [event["event_id"] for event in events]
        query_dict = {"query": {"ids": {"values": events_list}}}
        return query_dict

    @staticmethod
    def _build_query_dsl(query_dsl, timeline_ids):
        """Build OpenSearch Search DSL query by adding in timeline filtering.

        Args:
            query_dsl: A dict with the current query_dsl
            timeline_ids: Either a list of timeline IDs (int) or None.

        Returns:
            OpenSearch query DSL as a dictionary.
        """
        # Remove any aggregation coming from user supplied Query DSL.
        # We have no way to display this data in a good way today.
        if query_dsl.get("aggregations", None):
            del query_dsl["aggregations"]

        if not timeline_ids:
            return query_dsl

        if not isinstance(timeline_ids, (list, tuple)):
            es_logger.error(
                "Attempting to pass in timelines to a query DSL, but the "
                "passed timelines are not a list.")
            return query_dsl

        if not all([isinstance(x, int) for x in timeline_ids]):
            es_logger.error("All timeline IDs need to be an integer.")
            return query_dsl

        old_query = query_dsl.get("query")
        if not old_query:
            return query_dsl

        query_dsl["query"] = {
            "bool": {
                "must": [],
                "should": [
                    {
                        "bool": {
                            "must":
                            old_query,
                            "must_not": [{
                                "exists": {
                                    "field": "__ts_timeline_id"
                                },
                            }],
                        }
                    },
                    {
                        "bool": {
                            "must": [
                                {
                                    "terms": {
                                        "__ts_timeline_id": timeline_ids
                                    }
                                },
                                old_query,
                            ],
                            "must_not": [],
                            "filter": [{
                                "exists": {
                                    "field": "__ts_timeline_id"
                                }
                            }],
                        }
                    },
                ],
                "must_not": [],
                "filter": [],
            }
        }
        return query_dsl

    @staticmethod
    def _convert_to_time_range(interval):
        """Convert an interval timestamp into start and end dates.

        Args:
            interval: Time frame representation

        Returns:
            Start timestamp in string format.
            End timestamp in string format.
        """
        # return ('2018-12-05T00:00:00', '2018-12-05T23:59:59')
        TS_FORMAT = "%Y-%m-%dT%H:%M:%S"
        get_digits = lambda s: int("".join(filter(str.isdigit, s)))
        get_alpha = lambda s: "".join(filter(str.isalpha, s))

        ts_parts = interval.split(" ")
        # The start date could be 1 or 2 first items
        start = " ".join(ts_parts[0:len(ts_parts) - 2])
        minus = get_digits(ts_parts[-2])
        plus = get_digits(ts_parts[-1])
        interval = get_alpha(ts_parts[-1])

        start_ts = parser.parse(start)

        rd = relativedelta.relativedelta
        if interval == "s":
            start_range = start_ts - rd(seconds=minus)
            end_range = start_ts + rd(seconds=plus)
        elif interval == "m":
            start_range = start_ts - rd(minutes=minus)
            end_range = start_ts + rd(minutes=plus)
        elif interval == "h":
            start_range = start_ts - rd(hours=minus)
            end_range = start_ts + rd(hours=plus)
        elif interval == "d":
            start_range = start_ts - rd(days=minus)
            end_range = start_ts + rd(days=plus)
        else:
            raise RuntimeError("Unable to parse the timestamp: " +
                               str(interval))

        return start_range.strftime(TS_FORMAT), end_range.strftime(TS_FORMAT)

    def build_query(
        self,
        sketch_id,
        query_string,
        query_filter,
        query_dsl=None,
        aggregations=None,
        timeline_ids=None,
    ):
        """Build OpenSearch DSL query.

        Args:
            sketch_id: Integer of sketch primary key
            query_string: Query string
            query_filter: Dictionary containing filters to apply
            query_dsl: Dictionary containing OpenSearch DSL query
            aggregations: Dict of OpenSearch aggregations
            timeline_ids: Optional list of IDs of Timeline objects that should
                be queried as part of the search.

        Returns:
            OpenSearch DSL query as a dictionary
        """

        if query_dsl:
            if not isinstance(query_dsl, dict):
                query_dsl = json.loads(query_dsl)

            if not query_dsl:
                query_dsl = {}

            return self._build_query_dsl(query_dsl, timeline_ids)

        if query_filter.get("events", None):
            events = query_filter["events"]
            return self._build_events_query(events)

        query_dsl = {
            "query": {
                "bool": {
                    "must": [],
                    "must_not": [],
                    "filter": []
                }
            }
        }

        if query_string:
            query_dsl["query"]["bool"]["must"].append({
                "query_string": {
                    "query": query_string,
                    "default_operator": "AND"
                }
            })

        # New UI filters
        if query_filter.get("chips", None):
            labels = []
            must_filters = query_dsl["query"]["bool"]["must"]
            must_not_filters = query_dsl["query"]["bool"]["must_not"]
            datetime_ranges = {
                "bool": {
                    "should": [],
                    "minimum_should_match": 1
                }
            }

            for chip in query_filter["chips"]:
                # Exclude chips that the user disabled
                if not chip.get("active", True):
                    continue

                # Increase metrics per chip type
                METRICS["search_filter_type"].labels(type=chip["type"]).inc()
                if chip["type"] == "label":
                    labels.append(chip["value"])

                elif chip["type"] == "term":
                    term_filter = {
                        "match_phrase": {
                            "{}".format(chip["field"]): {
                                "query": "{}".format(chip["value"])
                            }
                        }
                    }

                    if chip["operator"] == "must":
                        must_filters.append(term_filter)

                    elif chip["operator"] == "must_not":
                        must_not_filters.append(term_filter)

                elif chip["type"].startswith("datetime"):
                    range_filter = lambda start, end: {
                        "range": {
                            "datetime": {
                                "gte": start,
                                "lte": end
                            }
                        }
                    }
                    if chip["type"] == "datetime_range":
                        start, end = chip["value"].split(",")
                    elif chip["type"] == "datetime_interval":
                        start, end = self._convert_to_time_range(chip["value"])
                    else:
                        continue
                    datetime_ranges["bool"]["should"].append(
                        range_filter(start, end))

            label_filter = self._build_labels_query(sketch_id, labels)
            must_filters.append(label_filter)
            must_filters.append(datetime_ranges)

        # Pagination
        if query_filter.get("from", None):
            query_dsl["from"] = query_filter["from"]

        # Number of events to return
        if query_filter.get("size", None):
            query_dsl["size"] = query_filter["size"]

        # Make sure we are sorting.
        if not query_dsl.get("sort", None):
            query_dsl["sort"] = {"datetime": query_filter.get("order", "asc")}

        # Add any pre defined aggregations
        if aggregations:
            # post_filter happens after aggregation so we need to move the
            # filter to the query instead.
            if query_dsl.get("post_filter", None):
                query_dsl["query"]["bool"]["filter"] = query_dsl["post_filter"]
                query_dsl.pop("post_filter", None)
            query_dsl["aggregations"] = aggregations

        # TODO: Simplify this when we don't have to support both timelines
        # that have __ts_timeline_id set and those that don't.
        # (query_string AND timeline_id NOT EXISTS) OR (
        #       query_string AND timeline_id in LIST)
        if timeline_ids and isinstance(timeline_ids, (list, tuple)):
            must_filters_pre = copy.copy(query_dsl["query"]["bool"]["must"])
            must_not_filters_pre = copy.copy(
                query_dsl["query"]["bool"]["must_not"])

            must_filters_post = copy.copy(query_dsl["query"]["bool"]["must"])
            must_not_filters_post = copy.copy(
                query_dsl["query"]["bool"]["must_not"])

            must_not_filters_pre.append({
                "exists": {
                    "field": "__ts_timeline_id"
                },
            })

            must_filters_post.append(
                {"terms": {
                    "__ts_timeline_id": timeline_ids
                }})

            query_dsl["query"] = {
                "bool": {
                    "must": [],
                    "should": [
                        {
                            "bool": {
                                "must": must_filters_pre,
                                "must_not": must_not_filters_pre,
                            }
                        },
                        {
                            "bool": {
                                "must":
                                must_filters_post,
                                "must_not":
                                must_not_filters_post,
                                "filter": [{
                                    "exists": {
                                        "field": "__ts_timeline_id"
                                    }
                                }],
                            }
                        },
                    ],
                    "must_not": [],
                    "filter": [],
                }
            }

        return query_dsl

    # pylint: disable=too-many-arguments
    def search(
        self,
        sketch_id,
        query_string,
        query_filter,
        query_dsl,
        indices,
        count=False,
        aggregations=None,
        return_fields=None,
        enable_scroll=False,
        timeline_ids=None,
    ):
        """Search OpenSearch. This will take a query string from the UI
        together with a filter definition. Based on this it will execute the
        search request on OpenSearch and get result back.

        Args:
            sketch_id: Integer of sketch primary key
            query_string: Query string
            query_filter: Dictionary containing filters to apply
            query_dsl: Dictionary containing OpenSearch DSL query
            indices: List of indices to query
            count: Boolean indicating if we should only return result count
            aggregations: Dict of OpenSearch aggregations
            return_fields: List of fields to return
            enable_scroll: If OpenSearch scroll API should be used
            timeline_ids: Optional list of IDs of Timeline objects that should
                be queried as part of the search.

        Returns:
            Set of event documents in JSON format
        """
        scroll_timeout = None
        if enable_scroll:
            scroll_timeout = "1m"  # Default to 1 minute scroll timeout

        # Exit early if we have no indices to query
        if not indices:
            return {"hits": {"hits": [], "total": 0}, "took": 0}

        # Make sure that the list of index names is unique.
        indices = list(set(indices))

        # Check if we have specific events to fetch and get indices.
        if query_filter.get("events", None):
            indices = {
                event["index"]
                for event in query_filter["events"]
                if event["index"] in indices
            }

        query_dsl = self.build_query(
            sketch_id=sketch_id,
            query_string=query_string,
            query_filter=query_filter,
            query_dsl=query_dsl,
            aggregations=aggregations,
            timeline_ids=timeline_ids,
        )

        # Default search type for OpenSearch is query_then_fetch.
        search_type = "query_then_fetch"

        # Only return how many documents matches the query.
        if count:
            if "sort" in query_dsl:
                del query_dsl["sort"]
            try:
                count_result = self.client.count(body=query_dsl,
                                                 index=list(indices))
            except NotFoundError:
                es_logger.error(
                    "Unable to count due to an index not found: {0:s}".format(
                        ",".join(indices)))
                return 0
            METRICS["search_requests"].labels(type="count").inc()
            return count_result.get("count", 0)

        if not return_fields:
            # Suppress the lint error because opensearchpy adds parameters
            # to the function with a decorator and this makes pylint sad.
            # pylint: disable=unexpected-keyword-arg
            return self.client.search(
                body=query_dsl,
                index=list(indices),
                search_type=search_type,
                scroll=scroll_timeout,
            )

        # The argument "_source_include" changed to "_source_includes" in
        # ES version 7. This check adds support for both version 6 and 7 clients.
        # pylint: disable=unexpected-keyword-arg
        try:
            if self.version.startswith("6"):
                _search_result = self.client.search(
                    body=query_dsl,
                    index=list(indices),
                    search_type=search_type,
                    _source_include=return_fields,
                    scroll=scroll_timeout,
                )
            else:
                _search_result = self.client.search(
                    body=query_dsl,
                    index=list(indices),
                    search_type=search_type,
                    _source_includes=return_fields,
                    scroll=scroll_timeout,
                )
        except RequestError as e:
            root_cause = e.info.get("error", {}).get("root_cause")
            if root_cause:
                error_items = []
                for cause in root_cause:
                    error_items.append("[{0:s}] {1:s}".format(
                        cause.get("type", ""), cause.get("reason", "")))
                cause = ", ".join(error_items)
            else:
                cause = str(e)

            es_logger.error("Unable to run search query: {0:s}".format(cause),
                            exc_info=True)
            raise ValueError(cause) from e

        METRICS["search_requests"].labels(type="single").inc()
        return _search_result

    # pylint: disable=too-many-arguments
    def search_stream(
        self,
        sketch_id=None,
        query_string=None,
        query_filter=None,
        query_dsl=None,
        indices=None,
        return_fields=None,
        enable_scroll=True,
        timeline_ids=None,
    ):
        """Search OpenSearch. This will take a query string from the UI
        together with a filter definition. Based on this it will execute the
        search request on OpenSearch and get result back.

        Args :
            sketch_id: Integer of sketch primary key
            query_string: Query string
            query_filter: Dictionary containing filters to apply
            query_dsl: Dictionary containing OpenSearch DSL query
            indices: List of indices to query
            return_fields: List of fields to return
            enable_scroll: Boolean determining whether scrolling is enabled.
            timeline_ids: Optional list of IDs of Timeline objects that should
                be queried as part of the search.

        Returns:
            Generator of event documents in JSON format
        """
        # Make sure that the list of index names is unique.
        indices = list(set(indices))

        METRICS["search_requests"].labels(type="stream").inc()

        if not query_filter.get("size"):
            query_filter["size"] = self.DEFAULT_STREAM_LIMIT

        if not query_filter.get("terminate_after"):
            query_filter["terminate_after"] = self.DEFAULT_STREAM_LIMIT

        result = self.search(
            sketch_id=sketch_id,
            query_string=query_string,
            query_dsl=query_dsl,
            query_filter=query_filter,
            indices=indices,
            return_fields=return_fields,
            enable_scroll=enable_scroll,
            timeline_ids=timeline_ids,
        )

        if enable_scroll:
            scroll_id = result["_scroll_id"]
            scroll_size = result["hits"]["total"]
        else:
            scroll_id = None
            scroll_size = 0

        # Elasticsearch version 7.x returns total hits as a dictionary.
        # TODO: Refactor when version 6.x has been deprecated.
        if isinstance(scroll_size, dict):
            scroll_size = scroll_size.get("value", 0)

        for event in result["hits"]["hits"]:
            yield event

        while scroll_size > 0:
            # pylint: disable=unexpected-keyword-arg
            result = self.client.scroll(scroll_id=scroll_id, scroll="5m")
            scroll_id = result["_scroll_id"]
            scroll_size = len(result["hits"]["hits"])
            for event in result["hits"]["hits"]:
                yield event

    def get_filter_labels(self, sketch_id, indices):
        """Aggregate labels for a sketch.

        Args:
            sketch_id: The Sketch ID
            indices: List of indices to aggregate on

        Returns:
            List with label names.
        """
        # This is a workaround to return all labels by setting the max buckets
        # to something big. If a sketch has more than this amount of labels
        # the list will be incomplete but it should be uncommon to have >10k
        # labels in a sketch.
        max_labels = 10000

        # pylint: disable=line-too-long
        aggregation = {
            "aggs": {
                "nested": {
                    "nested": {
                        "path": "timesketch_label"
                    },
                    "aggs": {
                        "inner": {
                            "filter": {
                                "bool": {
                                    "must": [{
                                        "term": {
                                            "timesketch_label.sketch_id":
                                            sketch_id
                                        }
                                    }]
                                }
                            },
                            "aggs": {
                                "labels": {
                                    "terms": {
                                        "size": max_labels,
                                        "field":
                                        "timesketch_label.name.keyword",
                                    }
                                }
                            },
                        }
                    },
                }
            }
        }

        # Make sure that the list of index names is unique.
        indices = list(set(indices))

        labels = []
        # pylint: disable=unexpected-keyword-arg
        try:
            result = self.client.search(index=indices,
                                        body=aggregation,
                                        size=0)
        except NotFoundError:
            es_logger.error("Unable to find the index/indices: {0:s}".format(
                ",".join(indices)))
            return labels

        buckets = (result.get("aggregations",
                              {}).get("nested",
                                      {}).get("inner",
                                              {}).get("labels",
                                                      {}).get("buckets", []))

        for bucket in buckets:
            new_bucket = {}
            new_bucket["label"] = bucket.pop("key")
            new_bucket["count"] = bucket.pop("doc_count")
            labels.append(new_bucket)
        return labels

    # pylint: disable=inconsistent-return-statements
    def get_event(self, searchindex_id, event_id):
        """Get one event from the datastore.

        Args:
            searchindex_id: String of OpenSearch index id
            event_id: String of OpenSearch event id

        Returns:
            Event document in JSON format
        """
        METRICS["search_get_event"].inc()
        try:
            # Suppress the lint error because opensearchpy adds parameters
            # to the function with a decorator and this makes pylint sad.
            # pylint: disable=unexpected-keyword-arg
            if self.version.startswith("6"):
                event = self.client.get(
                    index=searchindex_id,
                    id=event_id,
                    doc_type="_all",
                    _source_exclude=["timesketch_label"],
                )
            else:
                event = self.client.get(
                    index=searchindex_id,
                    id=event_id,
                    doc_type="_all",
                    _source_excludes=["timesketch_label"],
                )

            return event

        except NotFoundError:
            abort(HTTP_STATUS_CODE_NOT_FOUND)

    def count(self, indices):
        """Count number of documents.

        Args:
            indices: List of indices.

        Returns:
            Tuple containing number of documents and size on disk.
        """
        if not indices:
            return 0, 0

        # Make sure that the list of index names is unique.
        indices = list(set(indices))

        try:
            es_stats = self.client.indices.stats(index=indices,
                                                 metric="docs, store")

        except NotFoundError:
            es_logger.error("Unable to count indices (index not found)")
            return 0, 0

        except RequestError:
            es_logger.error("Unable to count indices (request error)",
                            exc_info=True)
            return 0, 0

        doc_count_total = (es_stats.get("_all",
                                        {}).get("primaries",
                                                {}).get("docs",
                                                        {}).get("count", 0))
        doc_bytes_total = (es_stats.get("_all", {}).get("primaries", {}).get(
            "store", {}).get("size_in_bytes", 0))

        return doc_count_total, doc_bytes_total

    def set_label(
        self,
        searchindex_id,
        event_id,
        event_type,
        sketch_id,
        user_id,
        label,
        toggle=False,
        remove=False,
        single_update=True,
    ):
        """Set label on event in the datastore.

        Args:
            searchindex_id: String of OpenSearch index id
            event_id: String of OpenSearch event id
            event_type: String of OpenSearch document type
            sketch_id: Integer of sketch primary key
            user_id: Integer of user primary key
            label: String with the name of the label
            remove: Optional boolean value if the label should be removed
            toggle: Optional boolean value if the label should be toggled
            single_update: Boolean if the label should be indexed immediately.

        Returns:
            Dict with updated document body, or None if this is a single update.
        """
        # OpenSearch painless script.
        update_body = {
            "script": {
                "lang": "painless",
                "source": UPDATE_LABEL_SCRIPT,
                "params": {
                    "timesketch_label": {
                        "name": str(label),
                        "user_id": user_id,
                        "sketch_id": sketch_id,
                    },
                    "remove": remove,
                },
            }
        }

        if toggle:
            update_body["script"]["source"] = TOGGLE_LABEL_SCRIPT

        if not single_update:
            script = update_body["script"]
            return dict(source=script["source"],
                        lang=script["lang"],
                        params=script["params"])

        doc = self.client.get(index=searchindex_id,
                              id=event_id,
                              doc_type="_all")
        try:
            doc["_source"]["timesketch_label"]
        except KeyError:
            doc = {"doc": {"timesketch_label": []}}
            self.client.update(index=searchindex_id,
                               doc_type=event_type,
                               id=event_id,
                               body=doc)

        self.client.update(index=searchindex_id,
                           id=event_id,
                           doc_type=event_type,
                           body=update_body)

        return None

    def create_index(self,
                     index_name=uuid4().hex,
                     doc_type="generic_event",
                     mappings=None):
        """Create index with Timesketch settings.

        Args:
            index_name: Name of the index. Default is a generated UUID.
            doc_type: Name of the document type. Default is generic_event.
            mappings: Optional dict with the document mapping for OpenSearch.

        Returns:
            Index name in string format.
            Document type in string format.
        """
        if mappings:
            _document_mapping = mappings
        else:
            _document_mapping = {
                "properties": {
                    "timesketch_label": {
                        "type": "nested"
                    },
                    "datetime": {
                        "type": "date"
                    },
                }
            }

        # TODO: Remove when we deprecate OpenSearch version 6.x
        if self.version.startswith("6"):
            _document_mapping = {doc_type: _document_mapping}

        if not self.client.indices.exists(index_name):
            try:
                self.client.indices.create(
                    index=index_name, body={"mappings": _document_mapping})
            except ConnectionError as e:
                raise RuntimeError(
                    "Unable to connect to Timesketch backend.") from e
            except RequestError:
                index_exists = self.client.indices.exists(index_name)
                es_logger.warning(
                    "Attempting to create an index that already exists "
                    "({0:s} - {1:s})".format(index_name, str(index_exists)))

        return index_name, doc_type

    def delete_index(self, index_name):
        """Delete OpenSearch index.

        Args:
            index_name: Name of the index to delete.
        """
        if self.client.indices.exists(index_name):
            try:
                self.client.indices.delete(index=index_name)
            except ConnectionError as e:
                raise RuntimeError(
                    "Unable to connect to Timesketch backend: {}".format(
                        e)) from e

    def import_event(
        self,
        index_name,
        event_type,
        event=None,
        event_id=None,
        flush_interval=DEFAULT_FLUSH_INTERVAL,
        timeline_id=None,
    ):
        """Add event to OpenSearch.

        Args:
            index_name: Name of the index in OpenSearch
            event_type: Type of event (e.g. plaso_event)
            event: Event dictionary
            event_id: Event OpenSearch ID
            flush_interval: Number of events to queue up before indexing
            timeline_id: Optional ID number of a Timeline object this event
                belongs to. If supplied an additional field will be added to
                the store indicating the timeline this belongs to.
        """
        if event:
            for k, v in event.items():
                if not isinstance(k, six.text_type):
                    k = codecs.decode(k, "utf8")

                # Make sure we have decoded strings in the event dict.
                if isinstance(v, six.binary_type):
                    v = codecs.decode(v, "utf8")

                event[k] = v

            # Header needed by OpenSearch when bulk inserting.
            header = {
                "index": {
                    "_index": index_name,
                }
            }
            update_header = {"update": {"_index": index_name, "_id": event_id}}

            # TODO: Remove when we deprecate Elasticsearch version 6.x
            if self.version.startswith("6"):
                header["index"]["_type"] = event_type
                update_header["update"]["_type"] = event_type

            if event_id:
                # Event has "lang" defined if there is a script used for import.
                if event.get("lang"):
                    event = {"script": event}
                else:
                    event = {"doc": event}
                header = update_header

            if timeline_id:
                event["__ts_timeline_id"] = timeline_id

            self.import_events.append(header)
            self.import_events.append(event)
            self.import_counter["events"] += 1

            if self.import_counter["events"] % int(flush_interval) == 0:
                _ = self.flush_queued_events()
                self.import_events = []
        else:
            # Import the remaining events in the queue.
            if self.import_events:
                _ = self.flush_queued_events()

        return self.import_counter["events"]

    def flush_queued_events(self, retry_count=0):
        """Flush all queued events.

        Returns:
            dict: A dict object that contains the number of events
                that were sent to OpenSearch as well as information
                on whether there were any errors, and what the
                details of these errors if any.
            retry_count: optional int indicating whether this is a retry.
        """
        if not self.import_events:
            return {}

        return_dict = {
            "number_of_events": len(self.import_events) / 2,
            "total_events": self.import_counter["events"],
        }

        try:
            # pylint: disable=unexpected-keyword-arg
            results = self.client.bulk(body=self.import_events,
                                       timeout=self._request_timeout)
        except (ConnectionTimeout, socket.timeout):
            if retry_count >= self.DEFAULT_FLUSH_RETRY_LIMIT:
                es_logger.error("Unable to add events, reached recount max.",
                                exc_info=True)
                return {}

            es_logger.error("Unable to add events (retry {0:d}/{1:d})".format(
                retry_count, self.DEFAULT_FLUSH_RETRY_LIMIT))
            return self.flush_queued_events(retry_count + 1)

        errors_in_upload = results.get("errors", False)
        return_dict["errors_in_upload"] = errors_in_upload

        if errors_in_upload:
            items = results.get("items", [])
            return_dict["errors"] = []

            es_logger.error("Errors while attempting to upload events.")
            for item in items:
                index = item.get("index", {})
                index_name = index.get("_index", "N/A")

                _ = self._error_container.setdefault(index_name, {
                    "errors": [],
                    "types": Counter(),
                    "details": Counter()
                })

                error_counter = self._error_container[index_name]["types"]
                error_detail_counter = self._error_container[index_name][
                    "details"]
                error_list = self._error_container[index_name]["errors"]

                error = index.get("error", {})
                status_code = index.get("status", 0)
                doc_id = index.get("_id", "(unable to get doc id)")
                caused_by = error.get("caused_by", {})

                caused_reason = caused_by.get("reason",
                                              "Unknown Detailed Reason")

                error_counter[error.get("type")] += 1
                detail_msg = "{0:s}/{1:s}".format(
                    caused_by.get("type", "Unknown Detailed Type"),
                    " ".join(caused_reason.split()[:5]),
                )
                error_detail_counter[detail_msg] += 1

                error_msg = "<{0:s}> {1:s} [{2:s}/{3:s}]".format(
                    error.get("type", "Unknown Type"),
                    error.get("reason", "No reason given"),
                    caused_by.get("type", "Unknown Type"),
                    caused_reason,
                )
                error_list.append(error_msg)
                try:
                    es_logger.error(
                        "Unable to upload document: {0:s} to index {1:s} - "
                        "[{2:d}] {3:s}".format(doc_id, index_name, status_code,
                                               error_msg))
                # We need to catch all exceptions here, since this is a crucial
                # call that we do not want to break operation.
                except Exception:  # pylint: disable=broad-except
                    es_logger.error(
                        "Unable to upload document, and unable to log the "
                        "error itself.",
                        exc_info=True,
                    )

        return_dict["error_container"] = self._error_container

        self.import_events = []
        return return_dict

    @property
    def version(self):
        """Get OpenSearch version.

        Returns:
          Version number as a string.
        """
        version_info = self.client.info().get("version")
        return version_info.get("number")
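
A small sketch exercising the static query helpers above; it can run without a live OpenSearch cluster or a Flask application context, although it still assumes the module-level imports of the original file (dateutil's parser and relativedelta). The sample event ID and the interval string are assumptions for illustration.

# Hypothetical inputs; the helpers below are @staticmethod, so no client is needed.
events = [{"event_id": "abc123", "index": "sketch-index-1"}]
print(OpenSearchDataStore._build_events_query(events))
# -> {'query': {'ids': {'values': ['abc123']}}}

start, end = OpenSearchDataStore._convert_to_time_range("2018-12-05 -1d +1d")
print(start, end)
# -> 2018-12-04T00:00:00 2018-12-06T00:00:00
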
Example #4
class OpenSearchDataStore():
    """Implements the datastore."""

    # Number of events to queue up when bulk inserting events.
    DEFAULT_FLUSH_INTERVAL = 20000
    DEFAULT_SIZE = 1000  # Max events to return

    def __init__(self, host='127.0.0.1', port=9200, url=None):
        """Create an OpenSearch client."""
        super().__init__()
        if url:
            self.client = OpenSearch([url], timeout=30)
        else:
            self.client = OpenSearch([{
                'host': host,
                'port': port
            }],
                                     timeout=30)
        self.import_counter = collections.Counter()
        self.import_events = []

    @staticmethod
    def build_query(query_string):
        """Build OpenSearch DSL query.

        Args:
            query_string: Query string

        Returns:
            OpenSearch DSL query as a dictionary
        """

        query_dsl = {
            'query': {
                'bool': {
                    'must': [{
                        'query_string': {
                            'query': query_string
                        }
                    }]
                }
            }
        }

        return query_dsl

    def create_index(self, index_name):
        """Create an index.

        Args:
            index_name: Name of the index

        Returns:
            Index name in string format.
        """
        if not self.client.indices.exists(index_name):
            try:
                self.client.indices.create(index=index_name)
            except exceptions.ConnectionError as e:
                raise RuntimeError(
                    'Unable to connect to backend datastore.') from e

        return index_name

    def delete_index(self, index_name):
        """Delete OpenSearch index.

        Args:
            index_name: Name of the index to delete.
        """
        if self.client.indices.exists(index_name):
            try:
                self.client.indices.delete(index=index_name)
            except exceptions.ConnectionError as e:
                raise RuntimeError(
                    'Unable to connect to backend datastore.') from e

    def import_event(self,
                     index_name,
                     event=None,
                     flush_interval=DEFAULT_FLUSH_INTERVAL):
        """Add event to OpenSearch.

        Args:
            index_name: Name of the index in OpenSearch
            event: Event dictionary
            flush_interval: Number of events to queue up before indexing

        Returns:
            The number of events processed.
        """
        if event:
            # Header needed by OpenSearch when bulk inserting.
            header = {'index': {'_index': index_name}}

            self.import_events.append(header)
            self.import_events.append(event)
            self.import_counter['events'] += 1

            if self.import_counter['events'] % int(flush_interval) == 0:
                self.client.bulk(body=self.import_events)
                self.import_events = []
        else:
            # Import the remaining events in the queue.
            if self.import_events:
                self.client.bulk(body=self.import_events)

        return self.import_counter['events']

    def index_exists(self, index_name):
        """Check if an index already exists.

        Args:
            index_name: Name of the index

        Returns:
            True if the index exists, False if not.
        """
        return self.client.indices.exists(index_name)

    def search(self, index_id, query_string, size=DEFAULT_SIZE):
        """Search OpenSearch.

        This will take a query string from the UI together with a filter
        definition. Based on this it will execute the search request on
        OpenSearch and get the result back.

        Args:
            index_id: Index to be searched
            query_string: Query string
            size: Maximum number of results to return

        Returns:
            Set of event documents in JSON format
        """

        query_dsl = self.build_query(query_string)

        # Default search type for OpenSearch is query_then_fetch.
        search_type = 'query_then_fetch'

        # pylint: disable=unexpected-keyword-arg
        return self.client.search(body=query_dsl,
                                  index=index_id,
                                  size=size,
                                  search_type=search_type)
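
A hedged end-to-end usage sketch for this simpler datastore; the host, index name, and event contents are assumptions, and a reachable OpenSearch node is required.

# Hypothetical usage: queue an event, flush it, then query it back.
store = OpenSearchDataStore(host='localhost', port=9200)
index_name = store.create_index('dfw-sample-index')
store.import_event(index_name,
                   event={'message': 'hello world',
                          'datetime': '2024-01-01T00:00:00'})
store.import_event(index_name)  # no event given -> flush the remaining queue
# Note: a freshly indexed document may need an index refresh before it is searchable.
result = store.search(index_name, query_string='message:hello')
print(result['hits']['total'])
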
Example #5
def search(
    client: OpenSearch,
    index: Optional[str] = "_all",
    search_body: Optional[Dict[str, Any]] = None,
    doc_type: Optional[str] = None,
    is_scroll: Optional[bool] = False,
    filter_path: Optional[Union[str, Collection[str]]] = None,
    **kwargs: Any,
) -> pd.DataFrame:
    """Return results matching query DSL as pandas dataframe.

    Parameters
    ----------
    client : OpenSearch
        instance of opensearchpy.OpenSearch to use.
    index : str, optional
        A comma-separated list of index names to search.
        use `_all` or empty string to perform the operation on all indices.
    search_body : Dict[str, Any], optional
        The search definition using the `Query DSL <https://opensearch.org/docs/opensearch/query-dsl/full-text/>`_.
    doc_type : str, optional
        Name of the document type (for Elasticsearch versions 5.x and earlier).
    is_scroll : bool, optional
        Allows to retrieve a large numbers of results from a single search request using
        `scroll <https://opensearch.org/docs/opensearch/rest-api/scroll/>`_
        for example, for machine learning jobs.
        Because scroll search contexts consume a lot of memory, we suggest you don’t use the scroll operation
        for frequent user queries.
    filter_path : Union[str, Collection[str]], optional
        Use the filter_path parameter to reduce the size of the OpenSearch Service response \
(default: ['hits.hits._id','hits.hits._source'])
    **kwargs :
        KEYWORD arguments forwarded to `opensearchpy.OpenSearch.search \
<https://opensearch-py.readthedocs.io/en/latest/api.html#opensearchpy.OpenSearch.search>`_
        and also to `opensearchpy.helpers.scan <https://opensearch-py.readthedocs.io/en/master/helpers.html#scan>`_
         if `is_scroll=True`

    Returns
    -------
    Union[pandas.DataFrame, Iterator[pandas.DataFrame]]
        Results as Pandas DataFrame

    Examples
    --------
    Searching an index using query DSL

    >>> import awswrangler as wr
    >>> client = wr.opensearch.connect(host='DOMAIN-ENDPOINT')
    >>> df = wr.opensearch.search(
    ...         client=client,
    ...         index='movies',
    ...         search_body={
    ...           "query": {
    ...             "match": {
    ...               "title": "wind"
    ...             }
    ...           }
    ...         }
    ...      )


    """
    if doc_type:
        kwargs["doc_type"] = doc_type

    if filter_path is None:
        filter_path = ["hits.hits._id", "hits.hits._source"]

    if is_scroll:
        if isinstance(filter_path, str):
            filter_path = [filter_path]
        filter_path = ["_scroll_id", "_shards"] + list(
            filter_path)  # required for scroll
        documents_generator = scan(client,
                                   index=index,
                                   query=search_body,
                                   filter_path=filter_path,
                                   **kwargs)
        documents = [_hit_to_row(doc) for doc in documents_generator]
        df = pd.DataFrame(documents)
    else:
        response = client.search(index=index,
                                 body=search_body,
                                 filter_path=filter_path,
                                 **kwargs)
        df = _search_response_to_df(response)
    return df
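
For large result sets, the same function can also drive the scroll API via `is_scroll=True`; the endpoint and index below are the same illustrative assumptions used in the docstring example.

>>> import awswrangler as wr
>>> client = wr.opensearch.connect(host='DOMAIN-ENDPOINT')
>>> df_all = wr.opensearch.search(
...     client=client,
...     index='movies',
...     search_body={"query": {"match_all": {}}},
...     is_scroll=True,  # retrieves all hits via opensearchpy.helpers.scan
... )
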
Example #6
class ElasticSearchDB(object):
    """
    .. class:: ElasticSearchDB

    :param str url: the url to the database for example: el.cern.ch:9200
    :param str gDebugFile: is used to save the debug information to a file
    :param int timeout: the default time out to Elasticsearch
    :param int RESULT_SIZE: The number of data points which will be returned by the query.
    """

    __url = ""
    __timeout = 120
    clusterName = ""
    RESULT_SIZE = 10000

    ########################################################################
    def __init__(
        self,
        host,
        port,
        user=None,
        password=None,
        indexPrefix="",
        useSSL=True,
        useCRT=False,
        ca_certs=None,
        client_key=None,
        client_cert=None,
    ):
        """c'tor

        :param self: self reference
        :param str host: hostname of the Elasticsearch server, for example: el.cern.ch
        :param int port: port on which Elasticsearch listens, for example: 9200
        :param str user: user name to access the db
        :param str password: if the db is password protected we need to provide a password
        :param str indexPrefix: it is the indexPrefix used to get all indexes
        :param bool useSSL: We can disable using secure connection. By default we use secure connection.
        :param bool useCRT: Use certificates.
        :param str ca_certs: CA certificates bundle.
        :param str client_key: Client key.
        :param str client_cert: Client certificate.
        """

        self.__indexPrefix = indexPrefix
        self._connected = False
        if user and password:
            sLog.debug("Specified username and password")
            if port:
                self.__url = "https://%s:%s@%s:%d" % (user, password, host,
                                                      port)
            else:
                self.__url = "https://%s:%s@%s" % (user, password, host)
        else:
            sLog.debug("Username and password not specified")
            if port:
                self.__url = "http://%s:%d" % (host, port)
            else:
                self.__url = "http://%s" % host

        if port:
            sLog.verbose("Connecting to %s:%s, useSSL = %s" %
                         (host, port, useSSL))
        else:
            sLog.verbose("Connecting to %s, useSSL = %s" % (host, useSSL))

        if useSSL:
            if ca_certs:
                casFile = ca_certs
            else:
                bd = BundleDeliveryClient()
                retVal = bd.getCAs()
                casFile = None
                if not retVal["OK"]:
                    sLog.error("CAs file does not exists:", retVal["Message"])
                    casFile = certifi.where()
                else:
                    casFile = retVal["Value"]

            self.client = Elasticsearch(self.__url,
                                        timeout=self.__timeout,
                                        use_ssl=True,
                                        verify_certs=True,
                                        ca_certs=casFile)
        elif useCRT:
            self.client = Elasticsearch(
                self.__url,
                timeout=self.__timeout,
                use_ssl=True,
                verify_certs=True,
                ca_certs=ca_certs,
                client_cert=client_cert,
                client_key=client_key,
            )
        else:
            self.client = Elasticsearch(self.__url, timeout=self.__timeout)

        # Before we use the database we try to connect
        # and retrieve the cluster name

        try:
            if self.client.ping():
                # Returns True if the cluster is running, False otherwise
                result = self.client.info()
                self.clusterName = result.get("cluster_name", " ")  # pylint: disable=no-member
                sLog.info("Database info\n", json.dumps(result, indent=4))
                self._connected = True
            else:
                sLog.error("Cannot ping ElasticsearchDB!")
        except ConnectionError as e:
            sLog.error(repr(e))

    ########################################################################
    def getIndexPrefix(self):
        """
        It returns the index prefix (derived from the DIRAC setup).
        """
        return self.__indexPrefix

    ########################################################################
    @ifConnected
    def query(self, index, query):
        """Executes a query and returns its result (uses ES DSL language).

        :param self: self reference
        :param str index: index name
        :param dict query: It is the query in ElasticSearch DSL language

        """
        try:
            esDSLQueryResult = self.client.search(index=index, body=query)
            return S_OK(esDSLQueryResult)
        except RequestError as re:
            return S_ERROR(re)

    @ifConnected
    def update(self, index, query=None, updateByQuery=True, id=None):
        """Executes an update of a document, and returns S_OK/S_ERROR

        :param self: self reference
        :param str index: index name
        :param dict query: It is the query in ElasticSearch DSL language
        :param bool updateByQuery: A bool to determine update by query or index values using index function.
        :param int id: ID for the document to be created.

        """

        sLog.debug("Updating %s with %s, updateByQuery=%s, id=%s" %
                   (index, query, updateByQuery, id))

        if not index or not query:
            return S_ERROR("Missing index or query")

        try:
            if updateByQuery:
                esDSLQueryResult = self.client.update_by_query(index=index,
                                                               body=query)
            else:
                esDSLQueryResult = self.client.index(index=index,
                                                     body=query,
                                                     id=id)
            return S_OK(esDSLQueryResult)
        except RequestError as re:
            return S_ERROR(re)

    @ifConnected
    def _Search(self, indexname):
        """
        It returns the object which can be used for retrieving certain values from the DB.
        """
        return Search(using=self.client, index=indexname)

    ########################################################################
    def _Q(self, name_or_query="match", **params):
        """
        It is a wrapper to ElasticDSL Query module used to create a query object.
        :param str name_or_query: the type of the query
        """
        return Q(name_or_query, **params)

    def _A(self, name_or_agg, aggsfilter=None, **params):
        """
        It is a wrapper to ElasticDSL aggregation module, used to create an aggregation
        """
        return A(name_or_agg, aggsfilter, **params)

    ########################################################################
    @ifConnected
    def getIndexes(self, indexName=None):
        """
        It returns the available indexes...
        """
        if not indexName:
            indexName = self.__indexPrefix
        sLog.debug("Getting indices alias of %s" % indexName)
        # We only return indexes which belong to a specific prefix, for example 'lhcb-production' or 'dirac-production', etc.
        return list(self.client.indices.get_alias("%s*" % indexName))

    ########################################################################
    @ifConnected
    def getDocTypes(self, indexName):
        """
        Returns mappings, by index.

        :param str indexName: is the name of the index...
        :return: S_OK or S_ERROR
        """
        result = []
        try:
            sLog.debug("Getting mappings for ", indexName)
            result = self.client.indices.get_mapping(indexName)
        except Exception as e:  # pylint: disable=broad-except
            sLog.exception()
            return S_ERROR(e)

        doctype = ""
        for indexConfig in result:
            if not result[indexConfig].get("mappings"):
                # there is a case when the mapping exists but the value is None...
                # this is usually an empty index or a corrupted index.
                sLog.warn("Index does not have mapping %s!" % indexConfig)
                continue
            if result[indexConfig].get("mappings"):
                doctype = result[indexConfig]["mappings"]
                break  # we assume the mapping of all indexes is the same...

        if not doctype:
            return S_ERROR("%s does not exist!" % indexName)

        return S_OK(doctype)

    ########################################################################
    @ifConnected
    def existingIndex(self, indexName):
        """
        Checks the existence of an index, by its name

        :param str indexName: the name of the index
        :returns: S_OK/S_ERROR if the request is successful
        """
        sLog.debug("Checking existence of index %s" % indexName)
        try:
            return S_OK(self.client.indices.exists(indexName))
        except TransportError as e:
            sLog.exception()
            return S_ERROR(e)

    ########################################################################

    @ifConnected
    def createIndex(self, indexPrefix, mapping=None, period="day"):
        """
        :param str indexPrefix: it is the index name.
        :param dict mapping: the configuration of the index.
        :param str period: We can specify which kind of index will be created
                           (day, week, month, year or null).

        """
        if period is not None:
            fullIndex = self.generateFullIndexName(
                indexPrefix,
                period)  # we have to create an index each period...
        else:
            sLog.warn(
                "The period is not provided, so using non-periodic index names"
            )
            fullIndex = indexPrefix

        res = self.existingIndex(fullIndex)
        if not res["OK"]:
            return res
        elif res["Value"]:
            return S_OK(fullIndex)

        try:
            sLog.info("Create index: ", fullIndex + str(mapping))
            self.client.indices.create(index=fullIndex,
                                       body={"mappings": mapping})  # ES7

            return S_OK(fullIndex)
        except Exception as e:  # pylint: disable=broad-except
            sLog.error("Can not create the index:", repr(e))
            return S_ERROR("Can not create the index")

    @ifConnected
    def deleteIndex(self, indexName):
        """
        :param str indexName: the name of the index to be deleted...
        """
        sLog.info("Deleting index", indexName)
        try:
            retVal = self.client.indices.delete(indexName)
        except NotFoundError:
            sLog.warn("Index does not exist", indexName)
            return S_OK("Nothing to delete")
        except ValueError as e:
            return S_ERROR(DErrno.EVALUE, e)

        if retVal.get("acknowledged"):
            # if the value exists and the value is not None
            sLog.info("Deleted index", indexName)
            return S_OK(indexName)

        return S_ERROR(retVal)

    def index(self, indexName, body=None, docID=None, op_type="index"):
        """
        :param str indexName: the name of the index to be used
        :param dict body: the data which will be indexed (basically the JSON)
        :param docID: optional document id
        :param str op_type: Explicit operation type. (options: 'index' (default) or 'create')
        :return: the index name in case of success.
        """

        sLog.debug("Indexing in %s body %s, id=%s" % (indexName, body, docID))

        if not indexName or not body:
            return S_ERROR("Missing index or body")

        try:
            res = self.client.index(index=indexName,
                                    body=body,
                                    id=docID,
                                    params={"op_type": op_type})
        except (RequestError, TransportError) as e:
            sLog.exception()
            return S_ERROR(e)

        if res.get("created") or res.get("result") in ("created", "updated"):
            # the created index exists but the value can be None.
            return S_OK(indexName)

        return S_ERROR(res)

    @ifConnected
    def bulk_index(self,
                   indexPrefix,
                   data=None,
                   mapping=None,
                   period="day",
                   withTimeStamp=True):
        """
        :param str indexPrefix: index name.
        :param list data: contains a list of dictionary
        :param dict mapping: the mapping used by elasticsearch
        :param str period: Accepts 'day' and 'month'. We can specify which kind of indexes will be created.
        :param bool withTimeStamp: add timestamp to data, if not there already.

        :returns: S_OK/S_ERROR
        """
        sLog.verbose("Bulk indexing",
                     "%d records will be inserted" % len(data))
        if mapping is None:
            mapping = {}

        if period is not None:
            indexName = self.generateFullIndexName(indexPrefix, period)
        else:
            indexName = indexPrefix
        sLog.debug("Bulk indexing into %s of %s" % (indexName, data))

        res = self.existingIndex(indexName)
        if not res["OK"]:
            return res
        if not res["Value"]:
            retVal = self.createIndex(indexPrefix, mapping, period)
            if not retVal["OK"]:
                return retVal

        try:
            res = bulk(client=self.client,
                       index=indexName,
                       actions=generateDocs(data, withTimeStamp))
        except (BulkIndexError, RequestError) as e:
            sLog.exception()
            return S_ERROR(e)

        if res[0] == len(data):
            # we have inserted all documents...
            return S_OK(len(data))
        else:
            return S_ERROR(res)

    @ifConnected
    def getUniqueValue(self, indexName, key, orderBy=False):
        """
        :param str indexName: the name of the index which will be used for the query
        :param str key: the field for which the unique values are collected
        :param dict orderBy: a dictionary in case we want to order the result, {key: 'desc'} or {key: 'asc'}
        :returns: a list of unique values for a certain key from the dictionary.
        """

        query = self._Search(indexName)

        endDate = datetime.utcnow()

        startDate = endDate - timedelta(days=30)

        timeFilter = self._Q(
            "range",
            timestamp={
                "lte": int(TimeUtilities.toEpoch(endDate)) * 1000,
                "gte": int(TimeUtilities.toEpoch(startDate)) * 1000,
            },
        )
        query = query.filter("bool", must=timeFilter)
        if orderBy:
            query.aggs.bucket(key,
                              "terms",
                              field=key,
                              size=self.RESULT_SIZE,
                              order=orderBy).metric(key,
                                                    "cardinality",
                                                    field=key)
        else:
            query.aggs.bucket(key, "terms", field=key,
                              size=self.RESULT_SIZE).metric(key,
                                                            "cardinality",
                                                            field=key)

        try:
            query = query.extra(
                size=self.RESULT_SIZE)  # do not need the raw data.
            sLog.debug("Query", query.to_dict())
            result = query.execute()
        except TransportError as e:
            return S_ERROR(e)

        values = []
        for bucket in result.aggregations[key].buckets:
            values += [bucket["key"]]
        del query
        sLog.debug("Nb of unique rows retrieved", len(values))
        return S_OK(values)

    def pingDB(self):
        """
        Try to connect to the database

        :return: S_OK(TRUE/FALSE)
        """
        connected = False
        try:
            connected = self.client.ping()
        except ConnectionError as e:
            sLog.error("Cannot connect to the db", repr(e))
        return S_OK(connected)

    @ifConnected
    def deleteByQuery(self, indexName, query):
        """
        Delete data by query (careful!)

        :param str indexName: the name of the index
        :param str query: the JSON-formatted query for which we want to issue the delete
        """
        try:
            self.client.delete_by_query(index=indexName, body=query)
        except Exception as inst:
            sLog.error("ERROR: Couldn't delete data")
            return S_ERROR(inst)
        return S_OK("Successfully deleted data from index %s" % indexName)

    @staticmethod
    def generateFullIndexName(indexName, period):
        """
        Given an index prefix we create the actual index name.

        :param str indexName: it is the name of the index
        :param str period: We can specify which kind of indexes will be created (day, week, month, year, null).
        :returns: string with full index name
        """

        # if the period is not correct, we use no-period indexes (same as "null").
        if period.lower() not in ["day", "week", "month", "year", "null"]:
            sLog.error("Period is not correct: ", period)
            return indexName
        elif period.lower() == "day":
            today = datetime.today().strftime("%Y-%m-%d")
            return "%s-%s" % (indexName, today)
        elif period.lower() == "week":
            week = datetime.today().isocalendar()[1]
            return "%s-%s" % (indexName, week)
        elif period.lower() == "month":
            month = datetime.today().strftime("%Y-%m")
            return "%s-%s" % (indexName, month)
        elif period.lower() == "year":
            year = datetime.today().strftime("%Y")
            return "%s-%s" % (indexName, year)
        elif period.lower() == "null":
            return indexName
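A minimal illustration of the index names generateFullIndexName produces for each supported period (a sketch only; the helper below just mirrors the suffix logic of the static method above, and the "lhcb-production" prefix is a hypothetical example taken from the comment in getIndexes):

# Illustration only: mirrors the period-to-suffix logic of generateFullIndexName above.
from datetime import datetime

def illustrate_full_index_name(indexName, period):
    period = period.lower()
    if period == "day":
        return "%s-%s" % (indexName, datetime.today().strftime("%Y-%m-%d"))
    if period == "week":
        return "%s-%s" % (indexName, datetime.today().isocalendar()[1])
    if period == "month":
        return "%s-%s" % (indexName, datetime.today().strftime("%Y-%m"))
    if period == "year":
        return "%s-%s" % (indexName, datetime.today().strftime("%Y"))
    return indexName  # "null" or an unrecognised period falls back to the bare prefix

for p in ("day", "week", "month", "year", "null"):
    print(p, "->", illustrate_full_index_name("lhcb-production", p))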
Example #7
class SearchApiClient:
    def __init__(self, host=settings.OPENSEARCH_HOST):

        protocol = settings.OPENSEARCH_PROTOCOL
        protocol_config = {}
        if protocol == "https":
            protocol_config = {
                "scheme": "https",
                "port": 443,
                "use_ssl": True,
                "verify_certs": settings.OPENSEARCH_VERIFY_CERTS,
            }

        if settings.IS_AWS:
            http_auth = ("supersurf", settings.OPENSEARCH_PASSWORD)
        else:
            http_auth = (None, None)

        self.client = OpenSearch([host],
                                 http_auth=http_auth,
                                 connection_class=RequestsHttpConnection,
                                 **protocol_config)
        self.index_nl = settings.OPENSEARCH_NL_INDEX
        self.index_en = settings.OPENSEARCH_EN_INDEX
        self.index_unk = settings.OPENSEARCH_UNK_INDEX
        self.languages = {"nl": self.index_nl, "en": self.index_en}

    @staticmethod
    def parse_search_result(search_result):
        """
        Parses the search result into the correct format that the frontend uses

        :param search_result: result from search
        :return result: dict with records, drilldowns and suggestions ready for the frontend
        """
        hits = search_result.pop("hits")
        aggregations = search_result.get("aggregations", {})
        result = dict()
        result['recordcount'] = hits['total']['value']

        # Transform aggregations into drilldowns
        drilldowns = {}
        for aggregation_name, aggregation in aggregations.items():
            buckets = aggregation["filtered"][
                "buckets"] if "filtered" in aggregation else aggregation[
                    "buckets"]
            for bucket in buckets:
                drilldowns[f"{aggregation_name}-{bucket['key']}"] = bucket[
                    "doc_count"]
        result['drilldowns'] = drilldowns

        # Parse spelling suggestions
        did_you_mean = {}
        if 'suggest' in search_result:
            spelling_suggestion = search_result['suggest'][
                'did-you-mean-suggestion'][0]
            spelling_option = spelling_suggestion['options'][0] if len(
                spelling_suggestion['options']) else None
            if spelling_option is not None and spelling_option["score"] >= 0.01:
                did_you_mean = {
                    'original': spelling_suggestion['text'],
                    'suggestion': spelling_option['text']
                }
        result['did_you_mean'] = did_you_mean

        # Transform hits into records
        result['records'] = [
            SearchApiClient.parse_search_hit(hit) for hit in hits['hits']
        ]
        return result

    @staticmethod
    def parse_search_hit(hit, transform=True):
        """
        Parses the search hit into the format that is also used by the edurep endpoint.
        It's mostly just mapping the variables we need into the places that we expect them to be.
        :param hit: result from search
        :param transform: whether to map source field names onto serializer field names
        :return record: parsed record
        """
        data = hit["_source"]
        serializer = SearchResultSerializer()
        # Basic mapping between field and data (excluding any method fields with a source of "*")
        field_mapping = {
            field.source: field_name if transform else field.source
            for field_name, field in serializer.fields.items()
            if field.source != "*"
        }
        record = {
            field_mapping[field]: value
            for field, value in data.items() if field in field_mapping
        }
        # Reformatting some fields if a relations field is desired
        if "relations" in field_mapping:
            publishers = [{
                "name": publisher
            } for publisher in data.get("publishers", [])]
            record["relations"] = {
                "authors":
                data.get("authors", []),
                "parties":
                data.get("parties", []) or publishers,
                "projects":
                data.get("projects", []),
                "keywords": [{
                    "label": keyword
                } for keyword in data.get("keywords", [])],
                "themes": [{
                    "label": theme
                } for theme in data.get("research_themes", [])],
                "parents":
                data.get("is_part_of", []),
                "children":
                data.get("has_parts", [])
            }
        # Calling methods on serializers to set data for method fields
        for field_name, field in serializer.fields.items():
            if field.source != "*":
                continue
            record[field_name] = getattr(serializer, field.method_name)(data)

        # Add highlight to the record
        if hit.get("highlight", 0):
            record["highlight"] = hit["highlight"]

        return record

    def autocomplete(self, query):
        """
        Use the suggest query to get typing hints during searching.

        :param query: the input from the user so far
        :return: a list of options matching the input query, sorted by length
        """
        # build the query for search engine
        query_dictionary = {
            'suggest': {
                "autocomplete": {
                    'text': query,
                    "completion": {
                        "field": "suggest_completion",
                        "size": 100
                    }
                }
            }
        }

        result = self.client.search(
            index=[self.index_nl, self.index_en, self.index_unk],
            body=query_dictionary)

        # extract the options from the search result, remove duplicates,
        # remove non-matching prefixes (engine will suggest things that don't match _exactly_)
        # and sort by length
        autocomplete = result['suggest']['autocomplete']
        options = autocomplete[0]['options']
        flat_options = list(
            set([
                item for option in options
                for item in option['_source']['suggest_completion']
            ]))
        options_with_prefix = [
            option for option in flat_options
            if option.lower().startswith(query.lower())
        ]
        options_with_prefix.sort(key=lambda option: len(option))
        return options_with_prefix

    def drilldowns(self, drilldown_names, search_text=None, filters=None):
        """
        This function is named drilldowns because it's also named drilldowns in the original edurep search code.
        It passes the information on to search, and returns the search result without the records.
        This allows calculation of 'item counts' (i.e. how many results there are for a certain filter).
        """
        search_results = self.search(search_text=search_text,
                                     filters=filters,
                                     drilldown_names=drilldown_names)
        search_results["records"] = []
        return search_results

    def search(self,
               search_text,
               drilldown_names=None,
               filters=None,
               ordering=None,
               page=1,
               page_size=5):
        """
        Build and send a query to search engine and parse it before returning.

        :param search_text: A list of strings to search for.
        :param drilldown_names: A list of the 'drilldowns' (filters) that are to be counted by engine.
        :param filters: The filters that are applied for this search.
        :param ordering: Sort the results by this ordering (or use default search ordering otherwise)
        :param page: The page index of the results
        :param page_size: How many items are loaded per page.
        :return:
        """

        start_record = page_size * (page - 1)
        body = {
            'query': {
                "bool": defaultdict(list)
            },
            'from': start_record,
            'size': page_size,
            'post_filter': {
                "bool": defaultdict(list)
            },
            'highlight': {
                'number_of_fragments': 1,
                'fragment_size': 120,
                'fields': {
                    'description': {},
                    'text': {}
                }
            }
        }

        if search_text:
            query_string = {
                "simple_query_string": {
                    "fields": SEARCH_FIELDS,
                    "query": search_text,
                    "default_operator": "and"
                }
            }
            body["query"]["bool"]["must"] += [query_string]
            body["query"]["bool"]["should"] = {
                "distance_feature": {
                    "field": "publisher_date",
                    "pivot": "90d",
                    "origin": "now",
                    "boost": 1.15
                }
            }
            body["suggest"] = {
                'did-you-mean-suggestion': {
                    'text': search_text,
                    'phrase': {
                        'field':
                        'suggest_phrase',
                        'size':
                        1,
                        'gram_size':
                        3,
                        'direct_generator': [{
                            'field': 'suggest_phrase',
                            'suggest_mode': 'always'
                        }],
                    },
                }
            }

        indices = self.parse_index_language(filters)

        if drilldown_names:
            body["aggs"] = self.parse_aggregations(drilldown_names, filters)

        filters = self.parse_filters(filters)
        if filters:
            body["post_filter"]["bool"]["must"] += filters

        if ordering:
            body["sort"] = [self.parse_ordering(ordering), "_score"]
        # make query and parse
        result = self.client.search(index=indices, body=body)
        return self.parse_search_result(result)

    def get_materials_by_id(self,
                            external_ids,
                            page=1,
                            page_size=10,
                            **kwargs):
        """
        Retrieve specific materials from search engine through their external id.

        :param external_ids: the id's of the materials to retrieve
        :param page: The page index of the results
        :param page_size: How many items are loaded per page.
        :return: a list of search results (like a regular search).
        """
        start_record = page_size * (page - 1)

        normalized_external_ids = []
        for external_id in external_ids:
            if not external_id.startswith("surf"):
                normalized_external_ids.append(external_id)
            else:
                external_id_parts = external_id.split(":")
                normalized_external_ids.append(external_id_parts[-1])

        result = self.client.search(
            index=[self.index_nl, self.index_en, self.index_unk],
            body={
                "query": {
                    "bool": {
                        "must": [{
                            "terms": {
                                "external_id": normalized_external_ids
                            }
                        }]
                    }
                },
                'from': start_record,
                'size': page_size,
            },
        )
        results = self.parse_search_result(result)
        materials = {
            material["external_id"]: material
            for material in results["records"]
        }
        records = []
        for external_id in normalized_external_ids:
            if external_id not in materials:
                continue
            records.append(materials[external_id])
        results["recordcount"] = len(records)
        results["records"] = records
        return results

    def stats(self):
        stats = self.client.count(
            index=",".join([self.index_nl, self.index_en, self.index_unk]))
        return stats.get("count", 0)

    def more_like_this(self, external_id, language):
        index = self.languages.get(language, self.index_unk)
        body = {
            "query": {
                "more_like_this": {
                    "fields": ["title", "description"],
                    "like": [{
                        "_index": index,
                        "_id": external_id
                    }],
                    "min_term_freq": 1,
                    "max_query_terms": 12
                }
            }
        }
        search_result = self.client.search(index=index, body=body)
        hits = search_result.pop("hits")
        result = dict()
        result["records_total"] = hits["total"]["value"]
        result["results"] = [
            SearchApiClient.parse_search_hit(hit, transform=False)
            for hit in hits["hits"]
        ]
        return result

    def author_suggestions(self, author_name):
        body = {
            "query": {
                "bool": {
                    "must": {
                        "multi_match": {
                            "fields": [
                                field for field in SEARCH_FIELDS
                                if "authors" not in field
                            ],
                            "query":
                            author_name,
                        },
                    },
                    "must_not": {
                        "match": {
                            "authors.name.folded": author_name
                        }
                    }
                }
            }
        }
        search_result = self.client.search(
            index=[self.index_nl, self.index_en, self.index_unk], body=body)
        hits = search_result.pop("hits")
        result = dict()
        result["records_total"] = hits["total"]["value"]
        result["results"] = [
            SearchApiClient.parse_search_hit(hit, transform=False)
            for hit in hits["hits"]
        ]
        return result

    @staticmethod
    def parse_filters(filters):
        """
        Parse filters from the frontend format into the search engine format.
        Not every filter is handled by the search engine in the same way, so it's a lot of manual parsing.

        :param filters: the list of filters to be parsed
        :return: the filters in the format for a search query.
        """
        if not filters:
            return {}
        filter_items = []
        for filter_item in filters:
            # skip filter_items that are empty
            # and the language filter item (it's handled by telling the search engine which index to search).
            if not filter_item['items'] or 'language.keyword' in filter_item[
                    'external_id']:
                continue
            search_type = filter_item['external_id']
            # date range query
            if search_type == "publisher_date":
                lower_bound, upper_bound = filter_item["items"]
                if lower_bound is not None or upper_bound is not None:
                    filter_items.append({
                        "range": {
                            "publisher_date": {
                                "gte": lower_bound,
                                "lte": upper_bound
                            }
                        }
                    })
            # all other filter types are handled by just using terms with the 'translated' filter items
            else:
                filter_items.append(
                    {"terms": {
                        search_type: filter_item["items"]
                    }})
        return filter_items

    def parse_aggregations(self, aggregation_names, filters):
        """
        Parse the aggregations so search engine can count the items properly.

        :param aggregation_names: the names of the aggregations to count
        :param filters: the filters for the query
        :return:
        """

        aggregation_items = {}
        for aggregation_name in aggregation_names:
            other_filters = []

            if filters:
                other_filters = list(
                    filter(lambda x: x['external_id'] != aggregation_name,
                           filters))
                other_filters = self.parse_filters(other_filters)

            search_type = aggregation_name

            if len(other_filters) > 0:
                # Filter the aggregation by the filters applied to other categories
                aggregation_items[aggregation_name] = {
                    "filter": {
                        "bool": {
                            "must": other_filters
                        }
                    },
                    "aggs": {
                        "filtered": {
                            "terms": {
                                "field": search_type,
                                "size": 2000,
                            }
                        }
                    },
                }
            else:
                aggregation_items[aggregation_name] = {
                    "terms": {
                        "field": search_type,
                        "size": 2000,
                    }
                }
        return aggregation_items

    @staticmethod
    def parse_ordering(ordering):
        """
        Parse the frontend ordering format (a field name, optionally prefixed with '-' for descending) into the sort clause that the search engine expects.
        """
        order = "asc"
        if ordering.startswith("-"):
            order = "desc"
            ordering = ordering[1:]
        search_type = ordering
        return {search_type: {"order": order}}

    def parse_index_language(self, filters):
        """
        Select the index to search on based on language.
        """
        # if no language is selected, search on both.
        indices = [self.index_nl, self.index_en, self.index_unk]
        if not filters:
            return indices
        language_item = [
            filter_item for filter_item in filters
            if filter_item['external_id'] == 'language.keyword'
        ]
        if not language_item:
            return indices
        language_indices = [
            f"latest-{language}" for language in language_item[0]['items']
        ]
        return language_indices if len(language_indices) else indices
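A minimal usage sketch of the client above (it assumes the OPENSEARCH_* Django settings referenced in __init__ are configured; the filter and drilldown values below are hypothetical examples of the format that parse_filters and parse_aggregations expect):

# Hypothetical example: field names such as "publishers.keyword" are placeholders.
client = SearchApiClient()

filters = [
    {"external_id": "language.keyword", "items": ["nl"]},              # routes the query to a language-specific index
    {"external_id": "publisher_date", "items": ["2020-01-01", None]},  # open-ended date range
]
results = client.search(
    search_text="machine learning",
    drilldown_names=["publishers.keyword"],
    filters=filters,
    page=1,
    page_size=5,
)
print(results["recordcount"], results["drilldowns"], len(results["records"]))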
Example #8
def test_es_basic_operations():
    """Run basic operations for testing purposes."""

    es = OpenSearch([{"host": "localhost", "port": 9200}])

    try:
        logging.debug("Deleting existing test data")
        es.delete(index="unit-test-index", doc_type="test", id=1)
    except exceptions.NotFoundError:
        pass

    logging.debug("Adding test data")
    r = es.index(
        index="unit-test-index",
        doc_type="test",
        id=1,
        body={
            "name": "Koira Koiruli Pöö",
            "height": "49",
            "mass": "10",
            "hair_color": "blond",
            "birth_year": "1999",
            "gender": "male",
        },
    )

    assert r["result"] == "created"

    es.indices.refresh(index="unit-test-index")
    r = es.get(index="unit-test-index", doc_type="test", id=1)
    assert r["_id"] == "1"

    s = es.search(index="unit-test-index",
                  body={"query": {
                      "match": {
                          "name": "cat"
                      }
                  }})
    hits = s["hits"]["total"]["value"]
    assert hits == 0

    s = es.search(index="unit-test-index", body={"query": {"match_all": {}}})
    logging.debug(s)
    hits = s["hits"]["total"]["value"]
    assert hits == 1

    s = es.search(index="unit-test-index",
                  body={"query": {
                      "match": {
                          "mass": "10"
                      }
                  }})
    logging.debug(s)
    hits = s["hits"]["total"]["value"]
    assert hits == 1

    s = es.search(index="unit-test-index",
                  body={"query": {
                      "match": {
                          "name": "Koiruli"
                      }
                  }})
    logging.debug(s)
    hits = s["hits"]["total"]["value"]
    assert hits == 1

    logging.debug("Deleting test data")
    es.delete(index="unit-test-index", doc_type="test", id=1)
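For comparison, a minimal sketch of the same round trip without the deprecated doc_type argument (this assumes the target cluster is OpenSearch or Elasticsearch 7+, where mapping types have been removed):

# Sketch: the same basic operations without doc_type (assumes an OpenSearch / ES 7+ cluster).
es = OpenSearch([{"host": "localhost", "port": 9200}])

es.index(index="unit-test-index", id=1, body={"name": "Koira Koiruli Pöö", "mass": "10"})
es.indices.refresh(index="unit-test-index")

doc = es.get(index="unit-test-index", id=1)
assert doc["_id"] == "1"

s = es.search(index="unit-test-index", body={"query": {"match": {"name": "Koiruli"}}})
assert s["hits"]["total"]["value"] == 1

es.delete(index="unit-test-index", id=1)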
Example #9
class ElasticsearchSampler():
    """Elasticsearch sample class
    """
    def __init__(self):
        host = 'localhost'
        port = 9200
        auth = ('admin', 'admin')
        certs = 'cert/root-ca.pem'

        # Create the Elasticsearch (OpenSearch) instance
        self.es = OpenSearch(
            hosts=[{
                'host': host,
                'port': port
            }],
            http_auth=auth,
            use_ssl=True,
            verify_certs=True,
            ca_certs=certs,
            ssl_assert_hostname=False,
            ssl_show_warn=False,
        )

    def __del__(self):
        # Close the Elasticsearch instance
        self.es.close()
        print("close elasticsearch instance--------------------------")

    def search(self, idx: str, query: str):

        result = self.es.search(index=idx, body=query)

        print(f'{type(result)}')
        print('--[search]-------------------------------------------')
        pprint.pprint(result, sort_dicts=False)

    def dslusage(self, index):
        # Search part (Search object)
        s = Search(using=self.es, index=index)
        s = s.filter(
            'range', **{
                '@timestamp': {
                    'gte': '2020-10-01T00:00:00+09:00',
                    'lte': '2020-10-01T23:59:59+09:00',
                    'format': 'date_time_no_millis'
                }
            })
        s = s.extra(size=0)

        # Aggregation part (Aggregation object)
        aggs_port = A("terms", field="destination.port", size=20)

        # Attach the Aggregation object to the Search object
        s.aggs.bucket("port-count", aggs_port)

        result = s.execute()

        # Extract the results (AttrList type)
        res_bucket = result.aggregations['port-count'].buckets
        print(f'==res_bucket : {res_bucket}')
        for item in res_bucket:
            print(f'port_count : {item}')
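A minimal driver for the sampler above (a sketch; it assumes a local cluster reachable with the admin/admin credentials and root CA used in __init__, and the index name "fluentd-20201001" is a hypothetical placeholder that must exist on the cluster):

# Hypothetical driver for ElasticsearchSampler; the index name is a placeholder.
if __name__ == "__main__":
    sampler = ElasticsearchSampler()
    sampler.search("fluentd-20201001", {"query": {"match_all": {}}, "size": 3})
    sampler.dslusage("fluentd-20201001")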