Example #1
def query_batch(
    client: GraphClient,
    batch_size: int,
    ttl_cutoff_ms: int,
    last_uid: Optional[str] = None,
) -> List[Dict[str, Union[Dict, str]]]:
    after = "" if last_uid is None else f", after: {last_uid}"
    paging = f"first: {batch_size}{after}"
    query = f"""
    {{
        q(func: le(last_index_time, {ttl_cutoff_ms}), {paging}) {{
            uid,
            expand(_all_) {{ uid }}
        }}
    }}
    """

    txn = client.txn(read_only=True)
    try:
        app.log.debug(f"retrieving batch: {query}")
        batch = txn.query(query)
        app.log.debug(f"retrieved batch: {batch.json}")
        return json.loads(batch.json)["q"]
    finally:
        txn.discard()
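A minimal driver sketch for the function above; the consumer handle_expired and the cutoff value cutoff_ms are hypothetical. It pages through expired nodes until a batch comes back empty, feeding the last uid back in as the cursor.

# Hypothetical pagination loop over query_batch.
last_uid = None
while True:
    batch = query_batch(client, batch_size=1000,
                        ttl_cutoff_ms=cutoff_ms, last_uid=last_uid)
    if not batch:
        break
    handle_expired(batch)  # hypothetical consumer
    last_uid = batch[-1]["uid"]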
Example #2
def _upsert(client: GraphClient, node_dict: Dict[str, Any]) -> str:
    node_dict["uid"] = "_:blank-0"
    node_key = node_dict["node_key"]
    query = f"""
        {{
            q0(func: eq(node_key, "{node_key}"), first: 1) {{
                    uid,
                    dgraph.type
                    expand(_all_)
            }}
        }}
        """
    txn = client.txn(read_only=False)

    try:
        res = json.loads(txn.query(query).json)["q0"]
        new_uid = None
        if res:
            node_dict["uid"] = res[0]["uid"]
            new_uid = res[0]["uid"]

        mutation = node_dict

        mut_res = txn.mutate(set_obj=mutation, commit_now=True)
        # For a brand-new node, node_dict["uid"] still holds the "_:blank-0"
        # placeholder, so take the uid Dgraph assigned to the blank node.
        new_uid = new_uid or mut_res.uids["blank-0"]
        return cast(str, new_uid)
    finally:
        txn.discard()
Example #3
def query_dgraph_type(client: GraphClient,
                      type_name: str) -> List[QueryPredicateResult]:
    query = f"""
        schema(type: {type_name}) {{ type }}
    """
    txn = client.txn(read_only=True)
    try:
        res = json.loads(txn.query(query).json)
    finally:
        txn.discard()

    if not res:
        return []
    if not res.get("types"):
        return []

    res = res["types"][0]["fields"]
    predicate_names = []
    for pred in res:
        predicate_names.append(pred["name"])

    predicate_metas = []
    for predicate_name in predicate_names:
        predicate_metas.append(query_dgraph_predicate(client, predicate_name))

    return predicate_metas
Example #4
def _upsert(client: GraphClient, node_dict: Dict[str, Any]) -> str:
    node_dict["uid"] = "_:blank-0"
    node_key = node_dict["node_key"]
    query = f"""
        {{
            q0(func: eq(node_key, "{node_key}"), first: 1) {{
                    uid,
                    dgraph.type
                    expand(_all_)
            }}
        }}
        """

    with client.txn_context(read_only=False) as txn:
        res = json.loads(txn.query(query).json)["q0"]
        new_uid = None
        if res:
            node_dict["uid"] = res[0]["uid"]
            new_uid = res[0]["uid"]

        mutation = node_dict

        m_res = txn.mutate(set_obj=mutation, commit_now=True)
        uids = m_res.uids

        if new_uid is None:
            new_uid = uids["blank-0"]
        return str(new_uid)
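A usage sketch for the upsert above; the node_key, type, and property values are illustrative, not from the original code.

# Hypothetical call: create-or-update a node keyed by node_key, getting
# back its uid whether it already existed or was just created.
uid = _upsert(client, {
    "node_key": "process-1234",   # hypothetical key
    "dgraph.type": "Process",     # hypothetical type
    "process_name": "example.exe",
})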
Example #5
    def query(self, graph_client: GraphClient, first: int) -> List[V]:
        var_alloc, query = gen_query(self, "q0", first=first)

        variables = {v: k for k, v in var_alloc.allocated.items()}

        with graph_client.txn_context(read_only=True) as txn:
            try:
                qres = json.loads(txn.query(query, variables=variables).json)
            except Exception as e:
                raise QueryFailedException(query, variables) from e

        d = qres.get("q0")
        if d:
            return [
                self.associated_viewable().from_dict(node, graph_client) for node in d
            ]
        return []
Example #6
def create_edge(client: GraphClient, from_uid: str, edge_name: str,
                to_uid: str) -> None:
    # A leading "~" marks a reverse edge: store it on the destination node
    # under the un-prefixed name, pointing back at the source.
    if edge_name[0] == "~":
        mut = {"uid": to_uid, edge_name[1:]: {"uid": from_uid}}
    else:
        mut = {"uid": from_uid, edge_name: {"uid": to_uid}}

    with client.txn_context(read_only=False) as txn:
        txn.mutate(set_obj=mut, commit_now=True)
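A usage sketch for the reverse-edge convention above; the uids and edge name are hypothetical.

# Forward edge: stored on from_uid under "children".
create_edge(client, from_uid="0x1a", edge_name="children", to_uid="0x2b")
# Reverse form: "~children" stores the edge on to_uid instead, pointing
# back at from_uid.
create_edge(client, from_uid="0x1a", edge_name="~children", to_uid="0x2b")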
Example #7
def set_property(client: GraphClient, uid: str, prop_name: str,
                 prop_value: Any) -> None:
    LOGGER.debug(f"Setting property {prop_name} as {prop_value} for {uid}")
    txn = client.txn(read_only=False)

    try:
        mutation = {"uid": uid, prop_name: prop_value}

        txn.mutate(set_obj=mutation, commit_now=True)
    finally:
        txn.discard()
Example #8
    def execute_file(
        self,
        name: str,
        file: str,
        graph: SubgraphView,
        sender: Connection,
        msg_id: str,
        chunk_size: int,
    ) -> None:
        try:
            pool = ThreadPool(processes=4)

            exec(file, globals())
            client = GraphClient()

            analyzers = get_analyzer_objects(client)
            if not analyzers:
                self.logger.warning(f"Got no analyzers for file: {name}")

            self.logger.info(f"Executing analyzers: {[an for an in analyzers.keys()]}")

            for nodes in chunker([n for n in graph.node_iter()], chunk_size):
                self.logger.info(f"Querying {len(nodes)} nodes")

                def exec_analyzer(
                    nodes: List[BaseView], sender: Connection
                ) -> List[BaseView]:
                    try:
                        self.exec_analyzers(
                            client, file, msg_id, nodes, analyzers, sender
                        )

                        return nodes
                    except Exception as e:
                        self.logger.error(traceback.format_exc())
                        self.logger.error(
                            f"Execution of {name} failed with {e} {e.args}"
                        )
                        sender.send(ExecutionFailed())
                        raise

                pool.apply_async(exec_analyzer, args=(nodes, sender))

            pool.close()

            pool.join()

            sender.send(ExecutionComplete())

        except Exception as e:
            self.logger.error(traceback.format_exc())
            self.logger.error(f"Execution of {name} failed with {e} {e.args}")
            sender.send(ExecutionFailed())
            raise
Example #9
def delete_nodes(client: GraphClient, nodes: Iterator[str]) -> int:
    del_ = [{"uid": uid} for uid in nodes]

    txn = client.txn()
    try:
        mut = txn.create_mutation(del_obj=del_)
        app.log.debug(f"deleting nodes: {mut}")
        txn.mutate(mutation=mut, commit_now=True)
        app.log.debug(f"deleted nodes: {json.dumps(del_)}")
        return len(del_)
    finally:
        txn.discard()
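A sketch tying this to query_batch from Example #1 for a TTL sweep; cutoff_ms is a hypothetical value.

# Fetch one batch of expired nodes, then delete them by uid.
batch = query_batch(client, batch_size=1000, ttl_cutoff_ms=cutoff_ms)
deleted = delete_nodes(client, (node["uid"] for node in batch))
app.log.info(f"swept {deleted} expired nodes")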
Example #10
def query_dgraph_predicate(client: GraphClient,
                           predicate_name: str) -> QueryPredicateResult:
    query = f"""
        schema(pred: {predicate_name}) {{  }}
    """
    txn = client.txn(read_only=True)
    try:
        res = json.loads(txn.query(query).json)["schema"][0]
    finally:
        txn.discard()

    return res
Example #11
def set_score(client: GraphClient,
              uid: str,
              new_score: int,
              txn: Any = None) -> None:
    # Only discard a transaction we created here; a caller-supplied txn
    # stays under the caller's control.
    owns_txn = txn is None
    if owns_txn:
        txn = client.txn(read_only=False)

    try:
        mutation = {"uid": uid, "score": new_score}

        txn.mutate(set_obj=mutation, commit_now=True)
    finally:
        if owns_txn:
            txn.discard()
Example #12
async def lambda_handler(s3_event: SQSMessageBody, context: Any) -> None:
    graph_client = GraphClient()
    s3 = S3ResourceFactory(boto3).from_env()
    metrics = create_metrics_client()

    for event in s3_event["Records"]:
        with metrics.time_to_process_event():
            try:
                _process_one_event(event, s3, graph_client, metrics)
            except BaseException:
                metrics.event_processed(status="failure")
                raise
            else:
                metrics.event_processed(status="success")
Example #13
def create_edge(client: GraphClient, from_uid: str, edge_name: str,
                to_uid: str) -> None:
    if edge_name[0] == "~":
        mut = {"uid": to_uid, edge_name[1:]: {"uid": from_uid}}

    else:
        mut = {"uid": from_uid, edge_name: {"uid": to_uid}}

    txn = client.txn(read_only=False)
    try:
        res = txn.mutate(set_obj=mut, commit_now=True)
        LOGGER.debug("edge mutation result is: {}".format(res))
    finally:
        txn.discard()
Example #14
def delete_edges(client: GraphClient, edges: Iterator[Tuple[str, str, str]]) -> int:
    del_ = [
        create_edge_obj(src_uid, predicate, dest_uid)
        for src_uid, predicate, dest_uid in edges
    ]

    txn = client.txn()
    try:
        mut = txn.create_mutation(del_obj=del_)
        app.log.debug(f"deleting edges: {mut}")
        txn.mutate(mutation=mut, commit_now=True)
        app.log.debug(f"deleted edges: {json.dumps(del_)}")
        return len(del_)
    finally:
        txn.discard()
Example #15
def add_reverse_edge_type(client: GraphClient, uid_type: UidType,
                          edge_name: str) -> None:
    LOGGER.debug(
        f"adding reverse edge type uid_type: {uid_type} edge_name: {edge_name}"
    )
    self_type = uid_type._inner_type.self_type()

    existing_predicates = query_dgraph_type(client, self_type)
    predicates = "\n\t\t".join(existing_predicates)

    # In case we've already deployed this plugin
    if edge_name in predicates:
        return

    predicates += f"\n\t\t<~{edge_name}>"

    type_str = f"""
    type {self_type} {{
        {predicates}
    }}\n
    """

    op = pydgraph.Operation(schema=type_str)
    client.alter(op)
Example #16
def execute_file(name: str, file: str, graph: SubgraphView,
                 sender: Connection, msg_id: str) -> None:
    try:
        pool = ThreadPool(processes=4)

        exec(file, globals())
        client = GraphClient()

        analyzers = get_analyzer_objects(client)
        if not analyzers:
            LOGGER.warning(f"Got no analyzers for file: {name}")

        LOGGER.info(f"Executing analyzers: {[an for an in analyzers.keys()]}")

        chunk_size = 100

        if IS_RETRY == "True":
            chunk_size = 10

        for nodes in chunker(list(graph.node_iter()), chunk_size):
            LOGGER.info(f"Querying {len(nodes)} nodes")

            def exec_analyzer(nodes, sender):
                try:
                    exec_analyzers(client, file, msg_id, nodes, analyzers,
                                   sender)

                    return nodes
                except Exception as e:
                    LOGGER.error(traceback.format_exc())
                    LOGGER.error(
                        f"Execution of {name} failed with {e} {e.args}")
                    sender.send(ExecutionFailed())
                    raise

            pool.apply_async(exec_analyzer, args=(nodes, sender))

        pool.close()

        pool.join()

        sender.send(ExecutionComplete())

    except Exception as e:
        LOGGER.error(traceback.format_exc())
        LOGGER.error(f"Execution of {name} failed with {e} {e.args}")
        sender.send(ExecutionFailed())
        raise
Example #17
    def get_or_create(gclient: GraphClient, lens_name: str, lens_type: str) -> LensView:
        with gclient.txn_context(read_only=False) as txn:
            query = """
            # lens get_or_create
            query res($a: string) 
            {
              res(func: eq(node_key, $a), first: 1) @cascade
               {
                 uid,
                 node_type: dgraph.type,
                 node_key,
               }
             }"""

            variables = {"$a": f"lens-{lens_type}{lens_name}"}

            res = txn.query(query, variables=variables)

            res = json.loads(res.json)["res"]
            new_uid = None
            if res:
                new_uid = res[0]["uid"]
            else:
                m_res = txn.mutate(
                    set_obj={
                        "lens_name": lens_name,
                        "lens_type": lens_type,
                        "node_key": "lens-" + lens_type + lens_name,
                        "dgraph.type": "Lens",
                        "score": 0,
                    },
                    commit_now=True,
                )
                uids = m_res.uids

                new_uid = new_uid or uids["blank-0"]

        self_lens_query = LensQuery().with_node_key(eq="lens-" + lens_type + lens_name)
        self_lens = self_lens_query.query_first(gclient)
        assert (
            self_lens
        ), f"Lens must exist, but couldn't query: {self_lens_query.debug_query()}"
        return self_lens
Example #18
    def get_or_create(gclient: GraphClient, lens_name: str,
                      lens_type: str) -> "LensView":
        eg_txn = gclient.txn(read_only=False)
        try:
            query = """
            query res($a: string)
            {
              res(func: eq(node_key, $a), first: 1) @cascade
               {
                 uid,
                 node_type: dgraph.type,
                 node_key,
               }
             }"""
            res = eg_txn.query(
                query, variables={"$a": "lens-" + lens_type + lens_name})

            res = json.loads(res.json)["res"]
            new_uid = None
            if res:
                new_uid = res[0]["uid"]
            else:
                m_res = eg_txn.mutate(
                    set_obj={
                        "lens": lens_name,
                        "lens_type": lens_type,
                        "node_key": "lens-" + lens_type + lens_name,
                        "dgraph.type": "Lens",
                        "score": 0,
                    },
                    commit_now=True,
                )
                uids = m_res.uids

                new_uid = new_uid or uids["blank-0"]
        finally:
            eg_txn.discard()

        self_lens = LensQuery().with_lens_name(eq=lens_name).query_first(gclient)
        assert self_lens, "Lens must exist"
        return self_lens
Example #19
def get_uid(client: GraphClient, node_key: str) -> str:
    with client.txn_context(read_only=True) as txn:
        query = """
            query res($a: string)
            {
              res(func: eq(node_key, $a), first: 1) @cascade
               {
                 uid,
               }
             }"""
        res = txn.query(query, variables={"$a": node_key})
        res = json.loads(res.json)

        if isinstance(res["res"], list):
            if res["res"]:
                return str(res["res"][0]["uid"])
            else:
                raise Exception(
                    f"get_uid failed for node_key: {node_key} {res}")
        else:
            return str(res["res"]["uid"])
Example #20
def query_dgraph_type(client: GraphClient, type_name: str) -> List[str]:
    query = f"""
        schema(type: {type_name}) {{ }}
    """
    LOGGER.debug(f"query: {query}")
    txn = client.txn(read_only=True)
    try:
        res = json.loads(txn.query(query).json)
        LOGGER.debug(f"res: {res}")
    finally:
        txn.discard()

    pred_names = []

    if "types" in res:
        for field in res["types"][0]["fields"]:
            pred_name = (f"<{field['name']}>"
                         if field["name"].startswith("~") else field["name"])
            pred_names.append(pred_name)

    return pred_names
Example #21
    def query_first(
        self,
        graph_client: GraphClient,
        contains_node_key: Optional[str] = None,
        best_effort: bool = False,
    ) -> Optional[V]:
        if contains_node_key:
            var_alloc, query = gen_query_parameterized(self, "q0", contains_node_key, 0)
        else:
            var_alloc, query = gen_query(self, "q0", first=1)

        variables = {v: k for k, v in var_alloc.allocated.items()}

        with graph_client.txn_context(read_only=True, best_effort=best_effort) as txn:
            try:
                qres = json.loads(txn.query(query, variables=variables).json)
            except Exception as e:
                raise QueryFailedException(query, variables) from e

        d = qres.get("q0")
        if d:
            return self.associated_viewable().from_dict(d[0], graph_client)
        return None
Example #22
    def test_single_file_contains_key(
        self,
        node_key,
        file_path,
        file_extension,
        file_mime_type,
        file_size,
        file_version,
        file_description,
        file_product,
        file_company,
        file_directory,
        file_inode,
        file_hard_links,
        signed,
        signed_status,
        md5_hash,
        sha1_hash,
        sha256_hash,
    ) -> None:
        node_key = "test_single_file_contains_key" + str(node_key)
        signed = "true" if signed else "false"

        graph_client = GraphClient()

        get_or_create_file_node(
            graph_client,
            node_key,
            file_path=file_path,
            file_extension=file_extension,
            file_mime_type=file_mime_type,
            file_size=file_size,
            file_version=file_version,
            file_description=file_description,
            file_product=file_product,
            file_company=file_company,
            file_directory=file_directory,
            file_inode=file_inode,
            file_hard_links=file_hard_links,
            signed=signed,
            signed_status=signed_status,
            md5_hash=md5_hash,
            sha1_hash=sha1_hash,
            sha256_hash=sha256_hash,
        )

        queried_proc = FileQuery().query_first(graph_client,
                                               contains_node_key=node_key)

        assert queried_proc is not None
        assert node_key == queried_proc.node_key

        assert file_path == (queried_proc.get_file_path() or "")
        assert file_extension == (queried_proc.get_file_extension() or "")
        assert file_mime_type == (queried_proc.get_file_mime_type() or "")
        assert file_version == (queried_proc.get_file_version() or "")
        assert file_description == (queried_proc.get_file_description() or "")
        assert file_product == (queried_proc.get_file_product() or "")
        assert file_company == (queried_proc.get_file_company() or "")
        assert file_directory == (queried_proc.get_file_directory() or "")
        assert file_hard_links == (queried_proc.get_file_hard_links() or "")
        assert signed == (queried_proc.get_signed() or "")
        assert signed_status == (queried_proc.get_signed_status() or "")
        assert md5_hash == (queried_proc.get_md5_hash() or "")
        assert sha1_hash == (queried_proc.get_sha1_hash() or "")
        assert sha256_hash == (queried_proc.get_sha256_hash() or "")
        assert file_size == queried_proc.get_file_size()
        assert file_inode == queried_proc.get_file_inode()
Example #23
    async def handle_events(self, events: SQSMessageBody, context: Any) -> None:
        # Parse sns message
        self.logger.debug(f"handling events: {events} context: {context}")

        client = GraphClient()

        s3 = S3ResourceFactory(boto3).from_env()

        load_plugins(
            self.model_plugins_bucket,
            s3.meta.client,
            os.path.abspath(MODEL_PLUGINS_DIR),
        )

        for event in events["Records"]:
            data = parse_s3_event(s3, event)

            # FIXME: this code assumes inner_message is json
            envelope = OldEnvelope.deserialize(data)
            message = json.loads(envelope.inner_message)

            LOGGER.info(f'Executing Analyzer: {message["key"]}')

            with self.metric_reporter.histogram_ctx(
                "analyzer-executor.download_s3_file"
            ):
                analyzer = download_s3_file(
                    s3,
                    self.analyzers_bucket,
                    message["key"],
                ).decode("utf8")
            analyzer_name = message["key"].split("/")[-2]

            subgraph = SubgraphView.from_proto(client, bytes(message["subgraph"]))

            # TODO: Validate signature of S3 file
            LOGGER.info(f"event {event} {envelope.metadata}")
            rx: Connection
            tx: Connection
            rx, tx = Pipe(duplex=False)
            p = Process(
                target=self.execute_file,
                args=(analyzer_name, analyzer, subgraph, tx, "", self.chunk_size),
            )

            p.start()

            for exec_hit in self.poll_process(rx=rx, analyzer_name=analyzer_name):
                with self.metric_reporter.histogram_ctx(
                    "analyzer-executor.emit_event.ms",
                    (TagPair("analyzer_name", exec_hit.analyzer_name),),
                ):
                    emit_event(
                        self.analyzer_matched_subgraphs_bucket,
                        s3,
                        exec_hit,
                        envelope.metadata,
                    )
                self.update_msg_cache(analyzer, exec_hit.root_node_key, message["key"])
                self.update_hit_cache(analyzer_name, exec_hit.root_node_key)

            p.join()
Example #24
def set_schema(client: GraphClient, schema: str) -> None:
    op = pydgraph.Operation(schema=schema, run_in_background=True)
    LOGGER.info(f"Setting dgraph schema: {schema}")
    client.alter(op, timeout=SecsDuration(5))
    LOGGER.info(f"Completed setting dgraph schema")
Example #25
    def test_single_file_view_parity_eq(
        self,
        node_key,
        file_path,
        file_extension,
        file_mime_type,
        file_size,
        file_version,
        file_description,
        file_product,
        file_company,
        file_directory,
        file_inode,
        file_hard_links,
        signed,
        signed_status,
        md5_hash,
        sha1_hash,
        sha256_hash,
    ) -> None:
        node_key = "test_single_file_view_parity_eq" + str(node_key)
        signed = "true" if signed else "false"
        graph_client = GraphClient()

        get_or_create_file_node(
            graph_client,
            node_key,
            file_path=file_path,
            file_extension=file_extension,
            file_mime_type=file_mime_type,
            file_size=file_size,
            file_version=file_version,
            file_description=file_description,
            file_product=file_product,
            file_company=file_company,
            file_directory=file_directory,
            file_inode=file_inode,
            file_hard_links=file_hard_links,
            signed=signed,
            signed_status=signed_status,
            md5_hash=md5_hash,
            sha1_hash=sha1_hash,
            sha256_hash=sha256_hash,
        )

        queried_file = (
            FileQuery()
            .with_node_key(eq=node_key)
            .with_file_path(eq=file_path)
            .with_file_extension(eq=file_extension)
            .with_file_mime_type(eq=file_mime_type)
            .with_file_size(eq=file_size)
            .with_file_version(eq=file_version)
            .with_file_description(eq=file_description)
            .with_file_product(eq=file_product)
            .with_file_company(eq=file_company)
            .with_file_directory(eq=file_directory)
            .with_file_inode(eq=file_inode)
            .with_file_hard_links(eq=file_hard_links)
            .with_signed(eq=signed)
            .with_signed_status(eq=signed_status)
            .with_md5_hash(eq=md5_hash)
            .with_sha1_hash(eq=sha1_hash)
            .with_sha256_hash(eq=sha256_hash)
            .query_first(graph_client)
        )

        assert queried_file is not None
        assert node_key == queried_file.node_key

        assert file_path == queried_file.get_file_path()
        assert file_extension == queried_file.get_file_extension()
        assert file_mime_type == queried_file.get_file_mime_type()
        assert file_size == queried_file.get_file_size()
        assert file_version == queried_file.get_file_version()
        assert file_description == queried_file.get_file_description()
        assert file_product == queried_file.get_file_product()
        assert file_company == queried_file.get_file_company()
        assert file_directory == queried_file.get_file_directory()
        assert file_inode == queried_file.get_file_inode()
        assert file_hard_links == queried_file.get_file_hard_links()
        assert signed == queried_file.get_signed()
        assert signed_status == queried_file.get_signed_status()
        assert md5_hash == queried_file.get_md5_hash()
        assert sha1_hash == queried_file.get_sha1_hash()
        assert sha256_hash == queried_file.get_sha256_hash()
Example #26
def lambda_handler_fn(events: Any, context: Any) -> None:
    # Parse sns message
    LOGGER.debug(f"handling events: {events} context: {context}")

    client = GraphClient()

    s3 = get_s3_client()

    load_plugins(os.environ["BUCKET_PREFIX"], s3,
                 os.path.abspath(MODEL_PLUGINS_DIR))

    for event in events["Records"]:
        if not IS_LOCAL:
            event = json.loads(event["body"])["Records"][0]
        data = parse_s3_event(s3, event)

        message = json.loads(data)

        LOGGER.info(f'Executing Analyzer: {message["key"]}')
        analyzer = download_s3_file(
            s3, f"{os.environ['BUCKET_PREFIX']}-analyzers-bucket",
            message["key"])
        analyzer_name = message["key"].split("/")[-2]

        subgraph = SubgraphView.from_proto(client, bytes(message["subgraph"]))

        # TODO: Validate signature of S3 file
        LOGGER.info(f"event {event}")
        rx: Connection
        tx: Connection
        rx, tx = Pipe(duplex=False)
        p = Process(target=execute_file,
                    args=(analyzer_name, analyzer, subgraph, tx, ""))

        p.start()
        t = 0

        while True:
            p_res = rx.poll(timeout=5)
            if not p_res:
                t += 1
                LOGGER.info(
                    f"Polled {analyzer_name} for {t * 5} seconds without result"
                )
                continue
            result: Optional[Any] = rx.recv()

            if isinstance(result, ExecutionComplete):
                LOGGER.info("execution complete")
                break

            # emit any hits to an S3 bucket
            if isinstance(result, ExecutionHit):
                LOGGER.info(
                    f"emitting event for {analyzer_name} {result.analyzer_name} {result.root_node_key}"
                )
                emit_event(s3, result)
                update_msg_cache(analyzer, result.root_node_key,
                                 message["key"])
                update_hit_cache(analyzer_name, result.root_node_key)

            assert not isinstance(
                result, ExecutionFailed), f"Analyzer {analyzer_name} failed."

        p.join()
Example #27
    def __init__(self,
                 query: BaseQuery,
                 graph_client: Optional[GraphClient] = None) -> None:
        self.query = query
        self.graph_client = graph_client or GraphClient()
Example #28
def set_schema(client: GraphClient, schema: str) -> None:
    op = pydgraph.Operation(schema=schema)
    client.alter(op)