def transpile_query(input_query, sketch_id):
    """Transpile a user-supplied Cypher query so it can run against a sketch.

    The query is parsed and cut into segments at the boundaries of every
    UNION clause. Segments that consist only of the keyword ``UNION`` or
    ``UNION ALL`` are passed through (stripped); every other segment goes
    through three steps, in this order:

    1. ``restrict_query_to_sketch`` inserts constraints (in match clauses,
       pattern expressions and pattern comprehensions) so that only nodes
       and edges with the given sketch_id property are accessible.
    2. ``append_return_clause`` appends a return clause generated from the
       identifiers bound in match clauses.
    3. ``unwind_timestamps`` inserts unwind clauses so that the
       ``timestamp`` property of edges is handled transparently (see the
       docstring of ``unwind_timestamps`` for details).

    The generated query returns rows with 3 variables:
      - nodes - list of node ids
      - edges - list of edge ids
      - timestamps - list of lists of timestamps, can be zipped with edges

    Args:
        input_query: Cypher query text written by the user.
        sketch_id: Id of the sketch the query must be restricted to.

    Returns:
        The transpiled query text, with the processed segments joined by a
        single space.

    Raises:
        pycypher.CypherParseError: If the input does not parse.
        InvalidQuery: If the input is not a read-only query or contains
            return clauses.
    """
    # TODO: take with clauses into account when generating return clauses
    (parsed,) = pycypher.parse_query(input_query)

    # Offsets at which the query text is cut: the very beginning, both ends
    # of each UNION clause, and the very end.
    boundaries = [0]
    for union_node in parsed.find_nodes('CYPHER_AST_UNION'):
        boundaries.extend((union_node.start, union_node.end))
    boundaries.append(len(input_query))

    transpiled_segments = []
    for begin, end in zip(boundaries, boundaries[1:]):
        segment = input_query[begin:end]
        # Whitespace-normalise before comparing so that e.g. "UNION  ALL"
        # is still recognised as a bare union keyword.
        if ' '.join(segment.split()) in ('UNION', 'UNION ALL'):
            transpiled_segments.append(segment.strip())
            continue
        segment = restrict_query_to_sketch(segment, sketch_id)
        segment = append_return_clause(segment)
        segment = unwind_timestamps(segment)
        transpiled_segments.append(segment.strip())

    return ' '.join(transpiled_segments)
def append_return_clause(input_query):
    """Append a generated RETURN clause to a query.

    The resulting query returns 3 list-valued columns:
      - nodes - list of ids of all match-bound nodes in the query
      - edges - list of ids of all match-bound edges in the query
      - timestamps - list of lists of timestamps, can be zipped with edges

    Only patterns of MATCH clauses are inspected. Node and relationship
    patterns without an identifier are ignored, and so are relationship
    patterns that have a variable-length specifier. Identifier names are
    sorted before being emitted, and ``LIMIT 10000`` is appended at the end.

    Args:
        input_query: Cypher query text without a return clause.

    Returns:
        The stripped input query followed by the generated RETURN clause.
    """
    (parsed,) = pycypher.parse_query(input_query)

    node_names = []
    edge_names = []
    for match_clause in parsed.find_nodes('CYPHER_AST_MATCH'):
        pattern = match_clause.get_pattern()
        if pattern is None:
            continue

        for node_pattern in pattern.find_nodes('CYPHER_AST_NODE_PATTERN'):
            identifier = node_pattern.get_identifier()
            if identifier is None:
                continue
            node_names.append(identifier.get_name())

        for rel_pattern in pattern.find_nodes('CYPHER_AST_REL_PATTERN'):
            identifier = rel_pattern.get_identifier()
            if identifier is None:
                continue
            # Relationships with a variable-length specifier are left out
            # of the edges and timestamps columns.
            if rel_pattern.get_varlength() is not None:
                continue
            edge_names.append(identifier.get_name())

    node_names.sort()
    edge_names.sort()

    node_ids = ', '.join('id(%s)' % name for name in node_names)
    edge_ids = ', '.join('id(%s)' % name for name in edge_names)
    edge_timestamps = ', '.join('%s.timestamps' % name for name in edge_names)

    return ''.join([
        input_query.strip(),
        ' RETURN ',
        '[', node_ids, '] AS nodes',
        ', ',
        '[', edge_ids, '] AS edges',
        ', ',
        '[', edge_timestamps, ']',
        ' AS timestamps',
        ' LIMIT 10000',
    ])
def restrict_query_to_sketch(input_query, sketch_id):
    """Restrict a Cypher query to nodes and edges of a single sketch.

    Inserts a ``sketch_id`` property constraint into every node pattern and
    relationship pattern of the query (this covers MATCH clauses, pattern
    expressions and pattern comprehensions), so that the query can access
    only nodes and edges whose sketch_id property equals the given
    sketch_id. The rewritten query is parsed again and checked with
    ``query_is_restricted_to_sketch`` before it is returned.

    Args:
        input_query: Cypher query text to restrict.
        sketch_id: Integer id of the sketch; formatted with ``%d``.

    Returns:
        The query text with the sketch_id constraints inserted.

    Raises:
        InvalidQuery: If the query contains a forbidden clause type, refers
            to the sketch_id property itself, or the rewritten query does
            not parse or does not pass the restriction check.
        pycypher.CypherParseError: If input_query does not parse.
    """
    forbidden = (
        'CYPHER_AST_CALL',
        'CYPHER_AST_START',
        'CYPHER_AST_LOAD_CSV',
        'CYPHER_AST_RETURN',
        'CYPHER_AST_MERGE',
        'CYPHER_AST_CREATE',
        'CYPHER_AST_SET',
        'CYPHER_AST_DELETE',
        'CYPHER_AST_REMOVE',
        'CYPHER_AST_SCHEMA_COMMAND',
    )
    query, = pycypher.parse_query(input_query)

    for ast_type in forbidden:
        if list(query.find_nodes(ast_type)):
            # Report the offending AST node type (not the builtin `type`).
            raise InvalidQuery('%s is not allowed.' % ast_type)

    for prop_name in query.find_nodes('CYPHER_AST_PROP_NAME'):
        if prop_name.get_value() == 'sketch_id':
            raise InvalidQuery('Accessing sketch_id property is not allowed.')

    q = InsertableString(input_query)

    def constrain_existing_properties(properties):
        # Insert one character past the start of the existing property map
        # (presumably right after its opening brace). A trailing ', ' is
        # needed only when the map already has keys.
        separator = ', ' if properties.get_keys() else ''
        q.insert_at(
            properties.start + 1,
            'sketch_id: %d%s' % (sketch_id, separator))

    for node_pattern in query.find_nodes('CYPHER_AST_NODE_PATTERN'):
        properties = node_pattern.get_properties()
        if properties is not None:
            constrain_existing_properties(properties)
        else:
            # No property map: add one just before the last character of
            # the node pattern (presumably its closing parenthesis).
            q.insert_at(node_pattern.end - 1, '{sketch_id: %d}' % sketch_id)

    # NOTE: the offsets below are derived from the pattern length and
    # direction, so they assume there are no spaces inside relationship
    # patterns; anything else is caught by the verification step at the end.
    for rel_pattern in query.find_nodes('CYPHER_AST_REL_PATTERN'):
        properties = rel_pattern.get_properties()
        outbound = rel_pattern.get_direction() == 'CYPHER_REL_OUTBOUND'
        length = rel_pattern.end - rel_pattern.start

        if properties is not None:
            constrain_existing_properties(properties)
        elif length == 2:
            # Two-character pattern (presumably `--`): a whole bracketed
            # part has to be added in the middle.
            q.insert_at(rel_pattern.end - 1, '[{sketch_id: %d}]' % sketch_id)
        elif length == 3:
            # Three-character pattern (presumably `-->` or `<--`): skip the
            # arrow head when it is at the end of the pattern.
            pos = rel_pattern.end - 2 if outbound else rel_pattern.end - 1
            q.insert_at(pos, '[{sketch_id: %d}]' % sketch_id)
        else:
            # Longer pattern without a property map (presumably it already
            # has brackets): insert the map a fixed distance from the end,
            # one character further back for outbound patterns.
            pos = rel_pattern.end - 3 if outbound else rel_pattern.end - 2
            q.insert_at(pos, '{sketch_id: %d}' % sketch_id)

    result = q.apply_insertions()

    # Verify the rewritten query: it must still parse and must pass the
    # restriction check.
    try:
        parsed_result, = pycypher.parse_query(result)
        is_ok = query_is_restricted_to_sketch(parsed_result, sketch_id)
    except pycypher.CypherParseError:
        is_ok = False
    if not is_ok:
        raise InvalidQuery(
            'Your query probably has spaces in relationship pattern or it '
            'has other non-standard constructs which are not allowed.')
    return result
def unwind_timestamps(input_query):
    """Rewrite a query so that edge `timestamp` references are unwound.

    Simplified description: for each bound edge e whose `e.timestamp` is
    referenced, replace all `e.timestamp` with `e_timestamp` and insert
    `UNWIND e.timestamps AS e_timestamp` before. Also, in the return
    clause, replace `e.timestamps` with `collect(e_timestamp)` for those
    edges.

    Additional handling of nulls accounts for missing timestamps, so that
    if an edge has `e.timestamps_incomplete = true`, this edge will be
    included in query results even if none of its unwound timestamps
    matches the conditions in where clauses.

    For exact and up-to-date reference and examples, please see the unit
    tests.

    Args:
        input_query: Cypher query text.

    Returns:
        The query text with all rewrites applied.
    """
    (tree,) = pycypher.parse_query(input_query)
    edits = InsertableString(input_query)
    # Names of edges for which an UNWIND clause has been inserted so far.
    handled_rels = []

    unwind_template = (
        'UNWIND %s.timestamps + '
        'filter(a IN [null] WHERE %s.timestamps_incomplete) '
        'AS %s_timestamp'
    )

    def timestamp_refs(rel_name, subtree):
        """Yield each `<rel_name>.timestamp` property access in subtree."""
        for access in subtree.find_nodes('CYPHER_AST_PROPERTY_OPERATOR'):
            target = access.get_expression()
            if target is None:
                continue
            prop = access.get_prop_name()
            if prop is None or prop.get_value() != 'timestamp':
                continue
            if (target.instanceof('CYPHER_AST_IDENTIFIER')
                    and target.get_name() == rel_name):
                yield access

    def references_timestamp(rel_name, subtree):
        """Tell whether subtree has any `<rel_name>.timestamp` access."""
        return any(True for _ in timestamp_refs(rel_name, subtree))

    def split_conjunction(predicate):
        """Flatten a tree of AND operators into a list of its operands."""
        if predicate is None:
            return []
        is_and = (
            predicate.instanceof('CYPHER_AST_BINARY_OPERATOR')
            and predicate.get_operator() == 'CYPHER_OP_AND'
        )
        if not is_and:
            return [predicate]
        left, right = predicate.get_argument1(), predicate.get_argument2()
        return split_conjunction(left) + split_conjunction(right)

    def end_of_non_predicate_children(match_clause):
        """Return the largest end offset among non-predicate children."""
        # pylint: disable=protected-access
        ends = [
            child.end for child in match_clause.children
            if 'predicate' not in child._roles
        ]
        return max([0] + ends)

    def add_unwind_clauses():
        for match_clause in tree.find_nodes('CYPHER_AST_MATCH'):
            pattern = match_clause.get_pattern()
            if pattern is None:
                continue

            # Edges bound in this clause that are not unwound yet and whose
            # timestamp is referenced anywhere in the query.
            pending = []
            for rel_pattern in pattern.find_nodes('CYPHER_AST_REL_PATTERN'):
                identifier = rel_pattern.get_identifier()
                if identifier is None:
                    continue
                name = identifier.get_name()
                if name in handled_rels:
                    continue
                if references_timestamp(name, tree):
                    pending.append(name)
            if not pending:
                continue

            # Conjuncts of the WHERE predicate that do not mention the
            # timestamp of any edge about to be unwound; these are repeated
            # ahead of the UNWIND clauses.
            repeatable = [
                part
                for part in split_conjunction(match_clause.get_predicate())
                if not any(references_timestamp(name, part)
                           for name in pending)
            ]

            where_prefix = ''
            if repeatable:
                where_prefix = 'WHERE ' + ' AND '.join(
                    input_query[part.start:part.end].strip()
                    for part in repeatable
                ) + ' '
            unwinds = ' '.join(
                unwind_template % (name, name, name) for name in pending)

            edits.insert_at(
                end_of_non_predicate_children(match_clause),
                ' ' + where_prefix + unwinds + ' WITH * ',
            )
            handled_rels.extend(pending)

    def guard_where_clauses_against_missing_timestamps():
        for clause in tree.find_nodes('CYPHER_AST_QUERY_CLAUSE'):
            predicate = clause.get_predicate()
            if predicate is None:
                continue
            involved = [
                name for name in handled_rels
                if references_timestamp(name, predicate)
            ]
            if not involved:
                continue
            # Wrap the predicate as coalesce((<predicate>), <null checks>).
            null_checks = ' OR '.join(
                '%s_timestamp IS NULL' % name for name in involved)
            edits.insert_at(predicate.start, ' coalesce((')
            edits.insert_at(predicate.end, '), %s)' % null_checks)

    def rewrite_timestamp_accesses():
        for name in handled_rels:
            for access in timestamp_refs(name, tree):
                # NOTE(review): this range stops at end - 1 while the
                # RETURN rewrite below uses end; kept exactly as before -
                # confirm against InsertableString.replace_range.
                edits.replace_range(
                    access.start, access.end - 1, '%s_timestamp' % name)

    def collect_timestamps_in_return_clauses():
        for return_clause in tree.find_nodes('CYPHER_AST_RETURN'):
            accesses = return_clause.find_nodes(
                'CYPHER_AST_PROPERTY_OPERATOR')
            for access in accesses:
                prop = access.get_prop_name()
                if prop is None or prop.get_value() != 'timestamps':
                    continue
                target = access.get_expression()
                if target is None:
                    continue
                if not target.instanceof('CYPHER_AST_IDENTIFIER'):
                    continue
                name = target.get_name()
                if name in handled_rels:
                    edits.replace_range(
                        access.start,
                        access.end,
                        'collect(%s_timestamp)' % name,
                    )

    # The order of these phases determines the order in which edits are
    # registered, so it must stay as is.
    add_unwind_clauses()
    guard_where_clauses_against_missing_timestamps()
    rewrite_timestamp_accesses()
    collect_timestamps_in_return_clauses()

    return edits.apply_insertions()