def test_do_gremlin_status_and_cancel(self):
    """Start a long-running Gremlin query on a background thread, verify it
    appears in the Gremlin status endpoint, cancel it, and verify the query
    thread observed a TimeLimitExceededException error payload.

    Requires a reachable Neptune endpoint (integration test).
    """
    query = "g.V().out().out().out().out()"
    query_res = {}
    gremlin_query_thread = threading.Thread(target=self.do_gremlin_query_save_results, args=(query, query_res,))
    gremlin_query_thread.start()
    time.sleep(3)  # give the query time to register as running on the server

    query_id = ''  # empty id => status for all queries
    request_generator = create_request_generator(AuthModeEnum.DEFAULT)
    status_res = do_gremlin_status(self.host, self.port, self.ssl, self.auth_mode, request_generator, query_id, False)
    self.assertEqual(type(status_res), dict)
    self.assertTrue('acceptedQueryCount' in status_res)
    self.assertTrue('runningQueryCount' in status_res)
    self.assertTrue(status_res['runningQueryCount'] == 1)
    self.assertTrue('queries' in status_res)

    # find the id of the query we started above
    query_id = ''
    for q in status_res['queries']:
        if query in q['queryString']:
            query_id = q['queryId']
    self.assertNotEqual(query_id, '')

    cancel_res = do_gremlin_cancel(self.host, self.port, self.ssl, self.auth_mode, request_generator, query_id)
    self.assertEqual(type(cancel_res), dict)
    self.assertTrue('status' in cancel_res)
    self.assertTrue('payload' in cancel_res)
    self.assertEqual('200 OK', cancel_res['status'])

    gremlin_query_thread.join()
    self.assertFalse('result' in query_res)
    self.assertTrue('error' in query_res)
    self.assertTrue('code' in query_res['error'])
    self.assertTrue('requestId' in query_res['error'])
    self.assertTrue('detailedMessage' in query_res['error'])
    # BUG FIX: the original asserted membership in query_res['error'] itself,
    # which tests the dict's KEYS ('code', 'requestId', 'detailedMessage') and
    # could never contain the exception name. The exception name is reported
    # in the 'code' field of the error payload.
    self.assertTrue('TimeLimitExceededException' in query_res['error']['code'])
def test_do_sparql_status_and_cancel(self):
    """Start a slow SPARQL query on a background thread, find it via the
    SPARQL status endpoint, cancel it (non-silent), and verify the query
    thread received a CancelledByUserException error payload.

    Requires a reachable Neptune endpoint with IAM auth (integration test).
    """
    query = "SELECT * WHERE { ?s ?p ?o . ?s2 ?p2 ?o2 .?s3 ?p3 ?o3 .} ORDER BY DESC(?s) LIMIT 100"
    query_res = {}
    sparql_query_thread = threading.Thread(target=self.do_sparql_query_save_result, args=(query, query_res,))
    sparql_query_thread.start()
    time.sleep(1)  # give the query time to register as running on the server

    query_id = ''  # empty id => status for all queries
    request_generator = create_request_generator(AuthModeEnum.IAM, IAMAuthCredentialsProvider.ENV)
    # BUG FIX: the original passed self.request_generator, an attribute that is
    # never defined on this test case; the local generator created above was
    # intended (matching the silent-cancel sibling test).
    status_res = do_sparql_status(self.host, self.port, self.ssl, request_generator, query_id)
    self.assertEqual(type(status_res), dict)
    self.assertTrue('acceptedQueryCount' in status_res)
    self.assertTrue('runningQueryCount' in status_res)
    self.assertTrue('queries' in status_res)
    time.sleep(1)

    # find the id of the query we started above
    query_id = ''
    for q in status_res['queries']:
        if query in q['queryString']:
            query_id = q['queryId']
    self.assertNotEqual(query_id, '')

    cancel_res = do_sparql_cancel(self.host, self.port, self.ssl, request_generator, query_id, False)
    self.assertEqual(type(cancel_res), dict)
    self.assertTrue('acceptedQueryCount' in cancel_res)
    self.assertTrue('runningQueryCount' in cancel_res)
    self.assertTrue('queries' in cancel_res)

    sparql_query_thread.join()
    self.assertFalse('result' in query_res)
    self.assertTrue('error' in query_res)
    self.assertTrue('code' in query_res['error'])
    self.assertTrue('requestId' in query_res['error'])
    self.assertTrue('detailedMessage' in query_res['error'])
    self.assertEqual('CancelledByUserException', query_res['error']['code'])
def load_ids(self, line):
    """Fetch the bulk-loader job ids from the endpoint and render them as labels."""
    config = self.graph_notebook_config
    generator = create_request_generator(config.auth_mode, config.iam_credentials_provider_type)
    res = get_loader_jobs(config.host, config.port, config.ssl, generator)

    if 'payload' in res and 'loadIds' in res['payload']:
        ids = res['payload']['loadIds']
    else:
        ids = []

    labels = [widgets.Label(value=load_id) for load_id in ids]
    if not labels:
        # nothing to show; render a placeholder instead of an empty box
        labels = [widgets.Label(value="No load IDs found.")]
    display(widgets.VBox(labels))
def load_status(self, line, local_ns: dict = None):
    """Print the status of a single bulk-load job, optionally storing the
    response dict into ``local_ns`` under the ``--store-to`` name."""
    parser = argparse.ArgumentParser()
    parser.add_argument('load_id', default='', help='loader id to check status for')
    parser.add_argument('--store-to', type=str, default='')
    args = parser.parse_args(line.split())

    config = self.graph_notebook_config
    generator = create_request_generator(config.auth_mode, config.iam_credentials_provider_type)
    res = get_load_status(config.host, config.port, config.ssl, generator, args.load_id)
    print(json.dumps(res, indent=2))

    if local_ns is not None and args.store_to != '':
        local_ns[args.store_to] = res
def test_do_sparql_status_and_cancel_silently(self):
    """Silently cancelling a running SPARQL query should let the query thread
    finish with a well-formed but empty result set (no error payload).

    Requires a reachable Neptune endpoint with IAM auth (integration test).
    """
    query = "SELECT * WHERE { ?s ?p ?o . ?s2 ?p2 ?o2 .?s3 ?p3 ?o3 .} ORDER BY DESC(?s) LIMIT 100"
    query_res = {}
    worker = threading.Thread(target=self.do_sparql_query_save_result, args=(query, query_res,))
    worker.start()
    time.sleep(1)  # let the query register as running

    generator = create_request_generator(AuthModeEnum.IAM, IAMAuthCredentialsProvider.ENV)
    status_res = do_sparql_status(self.host, self.port, self.ssl, generator, '')
    self.assertEqual(dict, type(status_res))
    for key in ('acceptedQueryCount', 'runningQueryCount', 'queries'):
        self.assertIn(key, status_res)

    # locate the id of the query we started (last match wins, as before)
    matches = [q['queryId'] for q in status_res['queries'] if query in q['queryString']]
    query_id = matches[-1] if matches else ''
    self.assertNotEqual('', query_id)

    cancel_res = do_sparql_cancel(self.host, self.port, self.ssl, generator, query_id, True)
    self.assertEqual(dict, type(cancel_res))
    for key in ('acceptedQueryCount', 'runningQueryCount', 'queries'):
        self.assertIn(key, cancel_res)

    worker.join()
    self.assertEqual(dict, type(query_res['result']))
    for var in ('s3', 'p3', 'o3'):
        self.assertIn(var, query_res['result']['head']['vars'])
    self.assertEqual([], query_res['result']['results']['bindings'])
def cancel_load(self, line, local_ns: dict = None):
    """Cancel a bulk-load job by id and report whether the cancel succeeded,
    optionally storing the raw response under the ``--store-to`` name."""
    parser = argparse.ArgumentParser()
    parser.add_argument('load_id', default='', help='loader id to check status for')
    parser.add_argument('--store-to', type=str, default='')
    args = parser.parse_args(line.split())

    config = self.graph_notebook_config
    generator = create_request_generator(config.auth_mode, config.iam_credentials_provider_type)
    # note: resolves to the module-level cancel_load helper, not this method
    res = cancel_load(config.host, config.port, config.ssl, generator, args.load_id)
    print('Cancelled successfully.' if res else 'Something went wrong cancelling bulk load job.')

    if local_ns is not None and args.store_to != '':
        local_ns[args.store_to] = res
def load_ids(self, line, local_ns: dict = None):
    """Render the bulk-loader job ids as labels; optionally store the raw
    response into ``local_ns`` under the ``--store-to`` name."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--store-to', type=str, default='')
    args = parser.parse_args(line.split())

    config = self.graph_notebook_config
    generator = create_request_generator(config.auth_mode, config.iam_credentials_provider_type)
    res = get_loader_jobs(config.host, config.port, config.ssl, generator)

    if 'payload' in res and 'loadIds' in res['payload']:
        ids = res['payload']['loadIds']
    else:
        ids = []

    labels = [widgets.Label(value=load_id) for load_id in ids]
    if not labels:
        # nothing to show; render a placeholder instead of an empty box
        labels = [widgets.Label(value="No load IDs found.")]
    display(widgets.VBox(labels))

    if local_ns is not None and args.store_to != '':
        local_ns[args.store_to] = res
def test_create_request_generator_sparql(self):
    """The 'sparql' command with DEFAULT auth should yield a SPARQLRequestGenerator."""
    generator = create_request_generator(AuthModeEnum.DEFAULT, command='sparql')
    self.assertEqual(SPARQLRequestGenerator, type(generator))
def load(self, line):
    """Interactive bulk-load magic: renders a form of ipywidgets (source,
    format, region, loader ARN, etc.), submits the load on button click,
    then polls the loader status every ``poll_interval`` seconds until the
    job reports LOAD_COMPLETED or a status check fails.

    :param line: magic-command argument line (unused by this variant).
    """
    # since this can be a long-running task, freezing variables in the case
    # that a user alters them in another command.
    host = self.graph_notebook_config.host
    port = self.graph_notebook_config.port
    ssl = self.graph_notebook_config.ssl
    credentials_provider_mode = self.graph_notebook_config.iam_credentials_provider_type
    request_generator = create_request_generator(self.graph_notebook_config.auth_mode, credentials_provider_mode)
    load_role_arn = self.graph_notebook_config.load_from_s3_arn
    region = self.graph_notebook_config.aws_region

    # form widgets for the load parameters
    button = widgets.Button(description="Submit")
    output = widgets.Output()
    source = widgets.Text(
        value='s3://',
        placeholder='Type something',
        description='Source:',
        disabled=False,
    )
    arn = widgets.Text(value=load_role_arn, placeholder='Type something', description='Load ARN:', disabled=False)
    source_format = widgets.Dropdown(
        options=['', 'csv', 'ntriples', 'nquads', 'rdfxml', 'turtle'],
        value='',  # blank so the user has to choose a format instead of risking the default one being incorrect.
        description='Format: ',
        disabled=False)
    region_box = widgets.Text(value=region, placeholder='us-east-1', description='AWS Region:', disabled=False)
    fail_on_error = widgets.Dropdown(options=['TRUE', 'FALSE'], value='TRUE', description='Fail on Failure: ', disabled=False)
    parallelism = widgets.Dropdown(options=['LOW', 'MEDIUM', 'HIGH', 'OVERSUBSCRIBE'], value='HIGH', description='Parallelism :', disabled=False)
    update_single_cardinality = widgets.Dropdown(
        options=['TRUE', 'FALSE'],
        value='FALSE',
        description='Update Single Cardinality:',
        disabled=False,
    )
    source_hbox = widgets.HBox([source])
    arn_hbox = widgets.HBox([arn])
    source_format_hbox = widgets.HBox([source_format])
    display(source_hbox, source_format_hbox, region_box, arn_hbox, fail_on_error, parallelism, update_single_cardinality, button, output)

    def on_button_clicked(b):
        # Submit handler: validate the form, kick off the load, then poll
        # for completion. Resetting children drops any stale validation labels.
        source_hbox.children = (source, )
        arn_hbox.children = (arn, )
        source_format_hbox.children = (source_format, )
        validated = True
        validation_label_style = DescriptionStyle(color='red')
        # source must be either an s3 uri (with something after 's3://') or
        # an absolute local file path
        if not (source.value.startswith('s3://') and len(source.value) > 7) and not source.value.startswith('/'):
            validated = False
            source_validation_label = widgets.HTML('<p style="color:red;">Source must be an s3 bucket or file path</p>')
            source_validation_label.style = validation_label_style
            source_hbox.children += (source_validation_label, )
        if source_format.value == '':
            validated = False
            source_format_validation_label = widgets.HTML('<p style="color:red;">Format cannot be blank.</p>')
            source_format_hbox.children += (source_format_validation_label, )
        if not arn.value.startswith('arn:aws') and source.value.startswith("s3://"):  # only do this validation if we are using an s3 bucket.
            validated = False
            arn_validation_label = widgets.HTML('<p style="color:red;">Load ARN must start with "arn:aws"</p>')
            arn_hbox.children += (arn_validation_label, )
        if not validated:
            return

        source_exp = os.path.expandvars(source.value)  # replace any env variables in source.value with their values, can use $foo or ${foo}. Particularly useful for ${AWS_REGION}
        logger.info(f'using source_exp: {source_exp}')
        try:
            load_result = do_load(host, port, source_format.value, ssl, str(source_exp), region_box.value, arn.value,
                                  fail_on_error.value, parallelism.value, update_single_cardinality.value, request_generator)
            # the form is no longer needed once the load has been submitted
            source_hbox.close()
            source_format_hbox.close()
            region_box.close()
            arn_hbox.close()
            fail_on_error.close()
            parallelism.close()
            update_single_cardinality.close()
            button.close()
            output.close()

            if 'status' not in load_result or load_result['status'] != '200 OK':
                with output:
                    print('Something went wrong.')
                    print(load_result)
                    logger.error(load_result)
                return

            load_id_label = widgets.Label(f'Load ID: {load_result["payload"]["loadId"]}')
            poll_interval = 5  # seconds between loader status checks
            interval_output = widgets.Output()
            job_status_output = widgets.Output()
            load_id_hbox = widgets.HBox([load_id_label])
            status_hbox = widgets.HBox([interval_output])
            vbox = widgets.VBox([load_id_hbox, status_hbox, job_status_output])
            display(vbox)

            last_poll_time = time.time()
            while True:
                time_elapsed = int(time.time() - last_poll_time)
                time_remaining = poll_interval - time_elapsed
                interval_output.clear_output()
                if time_elapsed > poll_interval:
                    # time to poll the loader status endpoint again
                    with interval_output:
                        print('checking status...')
                    job_status_output.clear_output()
                    with job_status_output:
                        display_html(HTML(loading_wheel_html))
                    try:
                        interval_check_response = get_load_status(host, port, ssl, request_generator, load_result['payload']['loadId'])
                    except Exception as e:
                        # any failure to obtain status ends the polling loop
                        logger.error(e)
                        with job_status_output:
                            print('Something went wrong updating job status. Ending.')
                        return
                    job_status_output.clear_output()
                    with job_status_output:
                        print(f'Overall Status: {interval_check_response["payload"]["overallStatus"]["status"]}')
                        if interval_check_response["payload"]["overallStatus"]["status"] == 'LOAD_COMPLETED':
                            interval_output.close()
                            print('Done.')
                            return
                    last_poll_time = time.time()
                else:
                    # countdown between polls, refreshed once per second
                    with interval_output:
                        print(f'checking status in {time_remaining} seconds')
                    time.sleep(1)
        except HTTPError as httpEx:
            # surface the server's error body to the user
            output.clear_output()
            with output:
                print(httpEx.response.content.decode('utf-8'))

    button.on_click(on_button_clicked)
def db_reset(self, line):
    """Fast-reset magic for the database system endpoint.

    Three modes driven by flags parsed from ``line``:
      * no flags: interactive confirmation UI (checkbox + Delete/Cancel
        buttons), then initiate + perform the reset and poll the status
        endpoint until it reports healthy again (or retries run out);
        ``-y``/``--yes`` skips the UI and resets immediately.
      * ``-g``/``--generate-token``: only initiate a reset and return the
        response containing the token.
      * ``-t TOKEN``/``--token TOKEN``: perform the reset with a
        previously generated token.

    :param line: magic-command argument line.
    :return: the endpoint response dict in the non-interactive modes;
             None (implicitly) in the interactive mode.
    """
    host = self.graph_notebook_config.host
    port = self.graph_notebook_config.port
    ssl = self.graph_notebook_config.ssl
    logger.info(f'calling system endpoint {host}')
    parser = argparse.ArgumentParser()
    parser.add_argument('-g', '--generate-token', action='store_true', help='generate token for database reset')
    parser.add_argument('-t', '--token', nargs=1, default='', help='perform database reset with given token')
    parser.add_argument('-y', '--yes', action='store_true', help='skip the prompt and perform database reset')
    args = parser.parse_args(line.split())
    generate_token = args.generate_token
    skip_prompt = args.yes
    request_generator = create_request_generator(self.graph_notebook_config.auth_mode, self.graph_notebook_config.iam_credentials_provider_type)
    logger.info(f'used credentials_provider_mode={self.graph_notebook_config.iam_credentials_provider_type.name} and auth_mode={self.graph_notebook_config.auth_mode.name} to make system request')

    if generate_token is False and args.token == '':
        if skip_prompt:
            # -y given: initiate, grab the one-time token, and reset at once
            res = initiate_database_reset(host, port, ssl, request_generator)
            token = res['payload']['token']
            res = perform_database_reset(token, host, port, ssl, request_generator)
            logger.info(f'got the response {res}')
            return res

        # interactive path: confirmation UI before anything destructive happens
        output = widgets.Output()
        source = 'Are you sure you want to delete all the data in your cluster?'
        label = widgets.Label(source)
        text_hbox = widgets.HBox([label])
        check_box = widgets.Checkbox(
            value=False,
            disabled=False,
            indent=False,
            description='I acknowledge that upon deletion the cluster data will no longer be available.',
            layout=widgets.Layout(width='600px', margin='5px 5px 5px 5px'))
        button_delete = widgets.Button(description="Delete")
        button_cancel = widgets.Button(description="Cancel")
        button_hbox = widgets.HBox([button_delete, button_cancel])
        display(text_hbox, check_box, button_hbox, output)

        def on_button_delete_clicked(b):
            # NOTE(review): the reset is initiated before the checkbox value is
            # inspected; presumably initiation alone is harmless without the
            # follow-up perform call — confirm against the endpoint's contract.
            result = initiate_database_reset(host, port, ssl, request_generator)

            text_hbox.close()
            check_box.close()
            button_delete.close()
            button_cancel.close()
            button_hbox.close()

            if not check_box.value:
                with output:
                    print('Checkbox is not checked.')
                return
            token = result['payload']['token']
            if token == "":
                with output:
                    print('Failed to get token.')
                    print(result)
                return

            result = perform_database_reset(token, host, port, ssl, request_generator)

            if 'status' not in result or result['status'] != '200 OK':
                with output:
                    print('Database reset failed, please try the operation again or reboot the cluster.')
                    print(result)
                    logger.error(result)
                return

            # poll the status endpoint until it reports healthy again
            retry = 10  # at most 10 status checks before giving up
            poll_interval = 5  # seconds between checks
            interval_output = widgets.Output()
            job_status_output = widgets.Output()
            status_hbox = widgets.HBox([interval_output])
            vbox = widgets.VBox([status_hbox, job_status_output])
            display(vbox)

            last_poll_time = time.time()
            while retry > 0:
                time_elapsed = int(time.time() - last_poll_time)
                time_remaining = poll_interval - time_elapsed
                interval_output.clear_output()
                if time_elapsed > poll_interval:
                    with interval_output:
                        print('checking status...')
                    job_status_output.clear_output()
                    with job_status_output:
                        display_html(HTML(loading_wheel_html))
                    try:
                        retry -= 1
                        interval_check_response = get_status(host, port, ssl, request_generator)
                    except Exception as e:
                        # Exception is expected when database is resetting, continue waiting
                        with job_status_output:
                            last_poll_time = time.time()
                            time.sleep(1)
                            continue
                    job_status_output.clear_output()
                    with job_status_output:
                        if interval_check_response["status"] == 'healthy':
                            interval_output.close()
                            print('Database has been reset.')
                            return
                    last_poll_time = time.time()
                else:
                    # countdown between polls, refreshed once per second
                    with interval_output:
                        print(f'checking status in {time_remaining} seconds')
                    time.sleep(1)

            # retries exhausted without seeing a healthy status
            with output:
                print(result)
                if interval_check_response["status"] != 'healthy':
                    print("Could not retrieve the status of the reset operation within the allotted time. "
                          "If the database is not healthy after 1 min, please try the operation again or "
                          "reboot the cluster.")

        def on_button_cancel_clicked(b):
            # tear down the confirmation UI without touching the database
            text_hbox.close()
            check_box.close()
            button_delete.close()
            button_cancel.close()
            button_hbox.close()
            with output:
                print('Database reset operation has been canceled.')

        button_delete.on_click(on_button_delete_clicked)
        button_cancel.on_click(on_button_cancel_clicked)
        return
    elif generate_token:
        res = initiate_database_reset(host, port, ssl, request_generator)
    else:
        # args.token is an array of a single string, e.g., args.token=['ade-23-c23'], use index 0 to take the string
        res = perform_database_reset(args.token[0], host, port, ssl, request_generator)
    logger.info(f'got the response {res}')
    return res
def gremlin(self, line, cell):
    """Gremlin cell magic: run the cell body as a Gremlin query and render
    the result in a tabbed widget.

    Modes (first positional argument of ``line``):
      * explain — render the server's explain output as HTML;
      * profile — render the server's profile output as HTML;
      * query (default) — execute the query, show a Console table tab and,
        when a network can be built from the result, a Graph tab.

    :param line: magic argument line; supports ``-p/--path-pattern``.
    :param cell: the Gremlin query text.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('query_mode', nargs='?', default='query', help='query mode (default=query) [query|explain|profile]')
    parser.add_argument('-p', '--path-pattern', default='', help='path pattern')
    args = parser.parse_args(line.split())
    mode = str_to_query_mode(args.query_mode)
    tab = widgets.Tab()
    if mode == QueryMode.EXPLAIN:
        request_generator = create_request_generator(self.graph_notebook_config.auth_mode, self.graph_notebook_config.iam_credentials_provider_type)
        raw_html = gremlin_explain(cell, self.graph_notebook_config.host, self.graph_notebook_config.port, self.graph_notebook_config.ssl, request_generator)
        explain_output = widgets.Output(layout=DEFAULT_LAYOUT)
        with explain_output:
            display(HTML(raw_html))
        tab.children = [explain_output]
        tab.set_title(0, 'Explain')
        display(tab)
    elif mode == QueryMode.PROFILE:
        request_generator = create_request_generator(self.graph_notebook_config.auth_mode, self.graph_notebook_config.iam_credentials_provider_type)
        raw_html = gremlin_profile(cell, self.graph_notebook_config.host, self.graph_notebook_config.port, self.graph_notebook_config.ssl, request_generator)
        profile_output = widgets.Output(layout=DEFAULT_LAYOUT)
        with profile_output:
            display(HTML(raw_html))
        tab.children = [profile_output]
        tab.set_title(0, 'Profile')
        display(tab)
    else:
        # plain query: execute via the websocket client provider
        client_provider = create_client_provider(self.graph_notebook_config.auth_mode, self.graph_notebook_config.iam_credentials_provider_type)
        res = do_gremlin_query(cell, self.graph_notebook_config.host, self.graph_notebook_config.port, self.graph_notebook_config.ssl, client_provider)
        children = []
        titles = []
        table_output = widgets.Output(layout=DEFAULT_LAYOUT)
        titles.append('Console')
        children.append(table_output)
        try:
            # try to visualize the result as a network; a ValueError simply
            # means the result is not graphable and we fall back to the table
            gn = GremlinNetwork()
            if args.path_pattern == '':
                gn.add_results(res)
            else:
                pattern = parse_pattern_list_str(args.path_pattern)
                gn.add_results_with_pattern(res, pattern)
            logger.debug(f'number of nodes is {len(gn.graph.nodes)}')
            if len(gn.graph.nodes) > 0:
                f = Force(network=gn, options=self.graph_notebook_vis_options)
                titles.append('Graph')
                children.append(f)
                logger.debug('added gremlin network to tabs')
        except ValueError as value_error:
            logger.debug(f'unable to create gremlin network from result. Skipping from result set: {value_error}')
        tab.children = children
        for i in range(len(titles)):
            tab.set_title(i, titles[i])
        display(tab)
        # render the table after display so the DOM node exists
        table_id = f"table-{str(uuid.uuid4()).replace('-', '')[:8]}"
        table_html = gremlin_table_template.render(guid=table_id, results=res)
        with table_output:
            display(HTML(table_html))
def sparql(self, line='', cell=''):
    """SPARQL cell magic: run the cell body as a SPARQL query and render the
    result in a tabbed widget (Table/Graph for SELECT-like queries, Raw for
    CONSTRUCT/DESCRIBE, and always a JSON tab). With mode 'explain' renders
    the server's explain output instead.

    :param line: query mode string (falls back to ``self.mode`` when empty);
                 '--expand-all' additionally expands all graph nodes.
    :param cell: the SPARQL query text.
    """
    request_generator = create_request_generator(self.graph_notebook_config.auth_mode, self.graph_notebook_config.iam_credentials_provider_type, command='sparql')
    if line != '':
        mode = str_to_query_mode(line)
    else:
        mode = self.mode
    tab = widgets.Tab()
    logger.debug(f'using mode={mode}')
    if mode == QueryMode.EXPLAIN:
        html_raw = sparql_explain(cell, self.graph_notebook_config.host, self.graph_notebook_config.port, self.graph_notebook_config.ssl, request_generator)
        explain_output = widgets.Output(layout=DEFAULT_LAYOUT)
        with explain_output:
            display(HTML(html_raw))
        tab.children = [explain_output]
        tab.set_title(0, 'Explain')
        display(tab)
    else:
        query_type = get_query_type(cell)
        # only result-producing query types request the JSON results format
        headers = {} if query_type not in ['SELECT', 'CONSTRUCT', 'DESCRIBE'] else {'Accept': 'application/sparql-results+json'}
        res = do_sparql_query(cell, self.graph_notebook_config.host, self.graph_notebook_config.port, self.graph_notebook_config.ssl, request_generator, headers)
        titles = []
        children = []
        display(tab)
        table_output = widgets.Output(layout=DEFAULT_LAYOUT)
        # Assign an empty value so we can always display to table output.
        # We will only add it as a tab if the type of query allows it.
        # Because of this, the table_output will only be displayed on the DOM if the query was of type SELECT.
        table_html = ""
        query_type = get_query_type(cell)
        if query_type in ['SELECT', 'CONSTRUCT', 'DESCRIBE']:
            logger.debug('creating sparql network...')
            # some issues with displaying a datatable when not wrapped in an hbox and displayed last
            hbox = widgets.HBox([table_output], layout=DEFAULT_LAYOUT)
            titles.append('Table')
            children.append(hbox)
            expand_all = line == '--expand-all'
            sn = SPARQLNetwork(expand_all=expand_all)
            sn.extract_prefix_declarations_from_query(cell)
            try:
                sn.add_results(res)
            except ValueError as value_error:
                # result not graphable — keep the table/JSON tabs only
                logger.debug(value_error)
            logger.debug(f'number of nodes is {len(sn.graph.nodes)}')
            if len(sn.graph.nodes) > 0:
                f = Force(network=sn, options=self.graph_notebook_vis_options)
                titles.append('Graph')
                children.append(f)
                logger.debug('added sparql network to tabs')
            rows_and_columns = get_rows_and_columns(res)
            if rows_and_columns is not None:
                table_id = f"table-{str(uuid.uuid4())[:8]}"
                table_html = sparql_table_template.render(columns=rows_and_columns['columns'], rows=rows_and_columns['rows'], guid=table_id)
            # Handling CONSTRUCT and DESCRIBE on their own because we want to maintain the previous result pattern
            # of showing a tsv with each line being a result binding in addition to new ones.
            if query_type == 'CONSTRUCT' or query_type == 'DESCRIBE':
                lines = []
                for b in res['results']['bindings']:
                    lines.append(f'{b["subject"]["value"]}\t{b["predicate"]["value"]}\t{b["object"]["value"]}')
                raw_output = widgets.Output(layout=DEFAULT_LAYOUT)
                with raw_output:
                    html = sparql_construct_template.render(lines=lines)
                    display(HTML(html))
                children.append(raw_output)
                titles.append('Raw')
        # the JSON tab is always present in the non-explain path
        json_output = widgets.Output(layout=DEFAULT_LAYOUT)
        with json_output:
            print(json.dumps(res, indent=2))
        children.append(json_output)
        titles.append('JSON')
        tab.children = children
        for i in range(len(titles)):
            tab.set_title(i, titles[i])
        with table_output:
            display(HTML(table_html))
def gremlin(self, line, cell, local_ns: dict = None):
    """Gremlin cell magic with grouping and result-storing support.

    Modes (first positional argument of ``line``):
      * explain — render the server's explain text;
      * profile — render the server's profile text;
      * query (default) — execute the query, show a Console table tab and,
        when a network can be built from the result, a Graph tab; the raw
        result can be stored into ``local_ns`` via ``--store-to``.

    :param line: magic argument line; supports -p/--path-pattern,
                 -g/--group-by, --store-to and --ignore-groups.
    :param cell: the Gremlin query text.
    :param local_ns: the notebook's local namespace for --store-to.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('query_mode', nargs='?', default='query', help='query mode (default=query) [query|explain|profile]')
    parser.add_argument('-p', '--path-pattern', default='', help='path pattern')
    parser.add_argument('-g', '--group-by', default='T.label', help='Property used to group nodes (e.g. code, T.region) default is T.label')
    parser.add_argument('--store-to', type=str, default='', help='store query result to this variable')
    parser.add_argument('--ignore-groups', action='store_true', help="Ignore all grouping options")
    args = parser.parse_args(line.split())
    mode = str_to_query_mode(args.query_mode)
    logger.debug(f'Arguments {args}')
    tab = widgets.Tab()
    if mode == QueryMode.EXPLAIN:
        request_generator = create_request_generator(self.graph_notebook_config.auth_mode, self.graph_notebook_config.iam_credentials_provider_type)
        query_res = do_gremlin_explain(cell, self.graph_notebook_config.host, self.graph_notebook_config.port, self.graph_notebook_config.ssl, request_generator)
        if 'explain' in query_res:
            html = pre_container_template.render(content=query_res['explain'])
        else:
            html = pre_container_template.render(content='No explain found')
        explain_output = widgets.Output(layout=DEFAULT_LAYOUT)
        with explain_output:
            display(HTML(html))
        tab.children = [explain_output]
        tab.set_title(0, 'Explain')
        display(tab)
    elif mode == QueryMode.PROFILE:
        request_generator = create_request_generator(self.graph_notebook_config.auth_mode, self.graph_notebook_config.iam_credentials_provider_type)
        query_res = do_gremlin_profile(cell, self.graph_notebook_config.host, self.graph_notebook_config.port, self.graph_notebook_config.ssl, request_generator)
        if 'profile' in query_res:
            html = pre_container_template.render(content=query_res['profile'])
        else:
            html = pre_container_template.render(content='No profile found')
        profile_output = widgets.Output(layout=DEFAULT_LAYOUT)
        with profile_output:
            display(HTML(html))
        tab.children = [profile_output]
        tab.set_title(0, 'Profile')
        display(tab)
    else:
        # plain query: execute via the websocket client provider
        client_provider = create_client_provider(self.graph_notebook_config.auth_mode, self.graph_notebook_config.iam_credentials_provider_type)
        query_res = do_gremlin_query(cell, self.graph_notebook_config.host, self.graph_notebook_config.port, self.graph_notebook_config.ssl, client_provider)
        children = []
        titles = []
        table_output = widgets.Output(layout=DEFAULT_LAYOUT)
        titles.append('Console')
        children.append(table_output)
        try:
            # try to visualize the result as a network; a ValueError simply
            # means the result is not graphable and we fall back to the table
            logger.debug(f'groupby: {args.group_by}')
            if args.ignore_groups:
                gn = GremlinNetwork()
            else:
                gn = GremlinNetwork(group_by_property=args.group_by)
            if args.path_pattern == '':
                gn.add_results(query_res)
            else:
                pattern = parse_pattern_list_str(args.path_pattern)
                gn.add_results_with_pattern(query_res, pattern)
            logger.debug(f'number of nodes is {len(gn.graph.nodes)}')
            if len(gn.graph.nodes) > 0:
                f = Force(network=gn, options=self.graph_notebook_vis_options)
                titles.append('Graph')
                children.append(f)
                logger.debug('added gremlin network to tabs')
        except ValueError as value_error:
            logger.debug(f'unable to create gremlin network from result. Skipping from result set: {value_error}')
        tab.children = children
        for i in range(len(titles)):
            tab.set_title(i, titles[i])
        display(tab)
        # render the table after display so the DOM node exists
        table_id = f"table-{str(uuid.uuid4()).replace('-', '')[:8]}"
        table_html = gremlin_table_template.render(guid=table_id, results=query_res)
        with table_output:
            display(HTML(table_html))
        store_to_ns(args.store_to, query_res, local_ns)
def test_do_status_with_iam_credentials(self):
    """The status endpoint should report healthy when authenticating with
    IAM credentials sourced from the environment."""
    generator = create_request_generator(AuthModeEnum.IAM, IAMAuthCredentialsProvider.ENV)
    health = get_status(self.host, self.port, self.ssl, generator)
    self.assertEqual(health['status'], 'healthy')
def test_do_sparql_cancel_non_str_query_id(self):
    """do_sparql_cancel must raise ValueError when the query id is not a string.

    BUG FIX: the original call passed the arguments in the wrong order
    (query_id and the silent flag first), so the non-string value landed in
    the host parameter instead of query_id. Every other call site in this
    suite uses (host, port, ssl, request_generator, query_id, silent).
    """
    with self.assertRaises(ValueError):
        query_id = 42
        request_generator = create_request_generator(AuthModeEnum.DEFAULT)
        do_sparql_cancel(self.host, self.port, self.ssl, request_generator, query_id, False)
def on_button_clicked(b=None):
    """Seed-submit handler: load every query of the chosen data set into the
    database, updating a progress bar as it goes. Aborts (closing the
    progress bar) on the first query error.

    Closure over the seed form's widgets (submit_button, language_dropdown,
    data_set_drop_down, output, progress_output) and self.
    """
    # lock the form while seeding runs
    submit_button.close()
    language_dropdown.disabled = True
    data_set_drop_down.disabled = True

    language = language_dropdown.value.lower()
    data_set = data_set_drop_down.value.lower()
    with output:
        print(f'Loading data set {data_set} with language {language}')
    queries = get_queries(language, data_set)
    if len(queries) < 1:
        with output:
            print('Did not find any queries for the given dataset')
        return

    load_index = 1  # start at 1 to have a non-empty progress bar
    progress = widgets.IntProgress(
        value=load_index,
        min=0,
        max=len(queries) + 1,  # len + 1 so we can start at index 1
        orientation='horizontal',
        bar_style='info',
        description='Loading:')
    with progress_output:
        display(progress)

    for q in queries:
        with output:
            print(f'{progress.value}/{len(queries)}:\t{q["name"]}')
        # Just like with the load command, seed is long-running
        # as such, we want to obtain the values of host, port, etc. in case they
        # change during execution.
        host = self.graph_notebook_config.host
        port = self.graph_notebook_config.port
        auth_mode = self.graph_notebook_config.auth_mode
        ssl = self.graph_notebook_config.ssl
        if language == 'gremlin':
            client_provider = create_client_provider(auth_mode, self.graph_notebook_config.iam_credentials_provider_type)
            # IMPORTANT: We treat each line as its own query!
            for line in q['content'].splitlines():
                try:
                    do_gremlin_query(line, host, port, ssl, client_provider)
                except GremlinServerError as gremlinEx:
                    # attempt to turn the server error into pretty-printed json
                    try:
                        error = json.loads(gremlinEx.args[0][5:])  # remove the leading error code.
                        content = json.dumps(error, indent=2)
                    except Exception:
                        content = {'error': gremlinEx}
                    with output:
                        print(content)
                    progress.close()
                    return
                except Exception as e:
                    content = {'error': e}
                    with output:
                        print(content)
                    progress.close()
                    return
        else:
            request_generator = create_request_generator(auth_mode, self.graph_notebook_config.iam_credentials_provider_type)
            try:
                do_sparql_query(q['content'], host, port, ssl, request_generator)
            except HTTPError as httpEx:
                # attempt to turn response into json
                try:
                    error = json.loads(httpEx.response.content.decode('utf-8'))
                    content = json.dumps(error, indent=2)
                except Exception:
                    content = {'error': httpEx}
                with output:
                    print(content)
                progress.close()
                return
            except Exception as ex:
                content = {'error': str(ex)}
                with output:
                    print(content)
                progress.close()
                return
        progress.value += 1
    progress.close()
    with output:
        print('Done.')
    return
def test_do_gremlin_cancel_non_str_query_id(self):
    """A non-string query id must be rejected with a ValueError."""
    with self.assertRaises(ValueError):
        bad_query_id = 42
        generator = create_request_generator(AuthModeEnum.DEFAULT)
        do_gremlin_cancel(self.host, self.port, self.ssl, self.auth_mode, generator, bad_query_id)
def load(self, line='', local_ns: dict = None):
    """Notebook magic: start a Neptune bulk-load job and poll it to completion.

    Parses loader options from ``line``, renders an ipywidgets form pre-filled
    with those options, and on submit validates the form, calls ``do_load``,
    then polls ``get_load_status`` every few seconds until the job reaches a
    final status. With ``--run`` the form is submitted immediately.

    :param line: magic argument string (source, format, ARN, region, flags).
    :param local_ns: notebook namespace used by ``--store-to``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('-s', '--source', default='s3://')
    parser.add_argument(
        '-l', '--loader-arn',
        default=self.graph_notebook_config.load_from_s3_arn)
    parser.add_argument('-f', '--format',
                        choices=LOADER_FORMAT_CHOICES, default='')
    parser.add_argument('-p', '--parallelism',
                        choices=PARALLELISM_OPTIONS, default=PARALLELISM_HIGH)
    parser.add_argument('-r', '--region',
                        default=self.graph_notebook_config.aws_region)
    parser.add_argument('--fail-on-failure', action='store_true', default=False)
    parser.add_argument('--update-single-cardinality', action='store_true',
                        default=True)
    parser.add_argument('--store-to', type=str, default='',
                        help='store query result to this variable')
    parser.add_argument('--run', action='store_true', default=False)
    args = parser.parse_args(line.split())

    # since this can be a long-running task, freezing variables in the case
    # that a user alters them in another command.
    host = self.graph_notebook_config.host
    port = self.graph_notebook_config.port
    ssl = self.graph_notebook_config.ssl
    credentials_provider_mode = self.graph_notebook_config.iam_credentials_provider_type
    request_generator = create_request_generator(
        self.graph_notebook_config.auth_mode, credentials_provider_mode)
    region = self.graph_notebook_config.aws_region

    # Build the input form; each widget is seeded from the parsed args.
    button = widgets.Button(description="Submit")
    output = widgets.Output()
    source = widgets.Text(
        value=args.source,
        placeholder='Type something',
        description='Source:',
        disabled=False,
    )
    arn = widgets.Text(value=args.loader_arn,
                       placeholder='Type something',
                       description='Load ARN:',
                       disabled=False)
    source_format = widgets.Dropdown(options=LOADER_FORMAT_CHOICES,
                                     value=args.format,
                                     description='Format: ',
                                     disabled=False)
    region_box = widgets.Text(value=region,
                              placeholder=args.region,
                              description='AWS Region:',
                              disabled=False)
    fail_on_error = widgets.Dropdown(options=['TRUE', 'FALSE'],
                                     value=str(args.fail_on_failure).upper(),
                                     description='Fail on Failure: ',
                                     disabled=False)
    parallelism = widgets.Dropdown(options=PARALLELISM_OPTIONS,
                                   value=args.parallelism,
                                   description='Parallelism :',
                                   disabled=False)
    update_single_cardinality = widgets.Dropdown(
        options=['TRUE', 'FALSE'],
        value=str(args.update_single_cardinality).upper(),
        description='Update Single Cardinality:',
        disabled=False,
    )

    # HBoxes so validation labels can be appended next to each field later.
    source_hbox = widgets.HBox([source])
    arn_hbox = widgets.HBox([arn])
    source_format_hbox = widgets.HBox([source_format])

    display(source_hbox, source_format_hbox, region_box, arn_hbox,
            fail_on_error, parallelism, update_single_cardinality, button,
            output)

    def on_button_clicked(b):
        # Reset each row to just its input widget, dropping stale
        # validation labels from a previous submit attempt.
        source_hbox.children = (source, )
        arn_hbox.children = (arn, )
        source_format_hbox.children = (source_format, )

        validated = True
        validation_label_style = DescriptionStyle(color='red')
        # Source must be an s3 URI with a bucket name, or a local file path.
        if not (source.value.startswith('s3://') and len(source.value) > 7
                ) and not source.value.startswith('/'):
            validated = False
            source_validation_label = widgets.HTML(
                '<p style="color:red;">Source must be an s3 bucket or file path</p>'
            )
            source_validation_label.style = validation_label_style
            source_hbox.children += (source_validation_label, )

        if source_format.value == '':
            validated = False
            source_format_validation_label = widgets.HTML(
                '<p style="color:red;">Format cannot be blank.</p>')
            source_format_hbox.children += (source_format_validation_label, )

        if not arn.value.startswith('arn:aws') and source.value.startswith(
                "s3://"):  # only do this validation if we are using an s3 bucket.
            validated = False
            arn_validation_label = widgets.HTML(
                '<p style="color:red;">Load ARN must start with "arn:aws"</p>'
            )
            arn_hbox.children += (arn_validation_label, )

        if not validated:
            return

        source_exp = os.path.expandvars(
            source.value
        )  # replace any env variables in source.value with their values, can use $foo or ${foo}. Particularly useful for ${AWS_REGION}
        logger.info(f'using source_exp: {source_exp}')
        try:
            load_result = do_load(host, port, source_format.value, ssl,
                                  str(source_exp), region_box.value,
                                  arn.value, fail_on_error.value,
                                  parallelism.value,
                                  update_single_cardinality.value,
                                  request_generator)
            store_to_ns(args.store_to, load_result, local_ns)

            # The submit form is no longer needed once the job is accepted.
            source_hbox.close()
            source_format_hbox.close()
            region_box.close()
            arn_hbox.close()
            fail_on_error.close()
            parallelism.close()
            update_single_cardinality.close()
            button.close()
            output.close()

            if 'status' not in load_result or load_result[
                    'status'] != '200 OK':
                with output:
                    print('Something went wrong.')
                    print(load_result)
                    logger.error(load_result)
                return

            # Poll the load job until it reaches a final status.
            load_id_label = widgets.Label(
                f'Load ID: {load_result["payload"]["loadId"]}')
            poll_interval = 5  # seconds between status checks
            interval_output = widgets.Output()
            job_status_output = widgets.Output()
            load_id_hbox = widgets.HBox([load_id_label])
            status_hbox = widgets.HBox([interval_output])
            vbox = widgets.VBox(
                [load_id_hbox, status_hbox, job_status_output])
            display(vbox)

            last_poll_time = time.time()
            while True:
                time_elapsed = int(time.time() - last_poll_time)
                time_remaining = poll_interval - time_elapsed
                interval_output.clear_output()
                if time_elapsed > poll_interval:
                    with interval_output:
                        print('checking status...')
                    job_status_output.clear_output()
                    with job_status_output:
                        display_html(HTML(loading_wheel_html))
                    try:
                        interval_check_response = get_load_status(
                            host, port, ssl, request_generator,
                            load_result['payload']['loadId'])
                    except Exception as e:
                        # Status-poll failures end polling; the load job
                        # itself may still be running server-side.
                        logger.error(e)
                        with job_status_output:
                            print(
                                'Something went wrong updating job status. Ending.'
                            )
                        return
                    job_status_output.clear_output()
                    with job_status_output:
                        print(
                            f'Overall Status: {interval_check_response["payload"]["overallStatus"]["status"]}'
                        )
                        if interval_check_response["payload"][
                                "overallStatus"][
                                    "status"] in FINAL_LOAD_STATUSES:
                            interval_output.close()
                            print('Done.')
                            return
                    last_poll_time = time.time()
                else:
                    with interval_output:
                        print(
                            f'checking status in {time_remaining} seconds')
                time.sleep(1)
        except HTTPError as httpEx:
            output.clear_output()
            with output:
                print(httpEx.response.content.decode('utf-8'))

    button.on_click(on_button_clicked)
    if args.run:
        # --run submits the pre-filled form without waiting for a click.
        on_button_clicked(None)
def setUp(self) -> None:
    """Cancel any SPARQL queries still running so each test starts clean."""
    generator = create_request_generator(AuthModeEnum.DEFAULT)
    status = do_sparql_status(self.host, self.port, self.ssl, generator)
    for query in status['queries']:
        do_sparql_cancel(self.host, self.port, self.ssl, generator,
                         query['queryId'], False)
def test_create_request_generator_default(self):
    """DEFAULT auth mode should produce a DefaultRequestGenerator."""
    generator = create_request_generator(AuthModeEnum.DEFAULT)
    self.assertEqual(DefaultRequestGenerator, type(generator))
def test_do_gremlin_cancel_non_str_query_id(self):
    """An integer query id must raise ValueError under IAM auth as well."""
    bad_query_id = 42
    generator = create_request_generator(AuthModeEnum.IAM,
                                         IAMAuthCredentialsProvider.ENV)
    with self.assertRaises(ValueError):
        do_gremlin_cancel(self.host, self.port, self.ssl, AuthModeEnum.IAM,
                          generator, bad_query_id)
def test_create_request_generator_iam_env(self):
    """IAM mode with ENV credentials yields an IamRequestGenerator whose
    credentials come from an EnvCredentialsProvider."""
    generator = create_request_generator(AuthModeEnum.IAM,
                                         IAMAuthCredentialsProvider.ENV)
    self.assertEqual(IamRequestGenerator, type(generator))
    self.assertEqual(EnvCredentialsProvider,
                     type(generator.credentials_provider))
def test_do_db_reset_initiate_with_iam_credentials(self):
    """Initiating a database reset with IAM auth returns a non-empty token."""
    generator = create_request_generator(AuthModeEnum.IAM,
                                         IAMAuthCredentialsProvider.ENV)
    response = initiate_database_reset(self.host, self.port, self.ssl,
                                       generator)
    self.assertNotEqual(response['payload']['token'], '')
def sparql(self, line='', cell='', local_ns: dict = None):
    """Notebook magic: run or explain the SPARQL query in ``cell``.

    In explain mode the rendered explain table is shown in a single tab.
    In query mode the result is shown across tabs (Table/Graph for SELECT,
    Raw for CONSTRUCT/DESCRIBE, and always JSON).

    :param line: magic arguments — query mode, ``--endpoint-prefix``,
        ``--expand-all``, ``--store-to``.
    :param cell: the SPARQL query text.
    :param local_ns: notebook namespace used by ``--store-to``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('query_mode', nargs='?', default='query',
                        help='query mode (default=query) [query|explain]')
    parser.add_argument('--endpoint-prefix', '-e', default='',
                        help='prefix path to sparql endpoint. For example, if "foo/bar" were specified, the endpoint called would be /foo/bar/sparql')
    parser.add_argument('--expand-all', action='store_true')
    request_generator = create_request_generator(
        self.graph_notebook_config.auth_mode,
        self.graph_notebook_config.iam_credentials_provider_type,
        command='sparql')
    parser.add_argument('--store-to', type=str, default='',
                        help='store query result to this variable')
    args = parser.parse_args(line.split())
    mode = str_to_query_mode(args.query_mode)
    endpoint_prefix = args.endpoint_prefix if args.endpoint_prefix != '' \
        else self.graph_notebook_config.sparql.endpoint_prefix

    tab = widgets.Tab()
    logger.debug(f'using mode={mode}')
    if mode == QueryMode.EXPLAIN:
        res = do_sparql_explain(cell, self.graph_notebook_config.host,
                                self.graph_notebook_config.port,
                                self.graph_notebook_config.ssl,
                                request_generator,
                                path_prefix=endpoint_prefix)
        store_to_ns(args.store_to, res, local_ns)
        if 'error' in res:
            html = error_template.render(
                error=json.dumps(res['error'], indent=2))
        else:
            html = sparql_explain_template.render(table=res)
        explain_output = widgets.Output(layout=DEFAULT_LAYOUT)
        with explain_output:
            display(HTML(html))
        tab.children = [explain_output]
        tab.set_title(0, 'Explain')
        display(tab)
    else:
        query_type = get_query_type(cell)
        # Only result-producing query forms get the JSON results Accept header.
        headers = {} if query_type not in ['SELECT', 'CONSTRUCT', 'DESCRIBE'] else {
            'Accept': 'application/sparql-results+json'}
        res = do_sparql_query(cell, self.graph_notebook_config.host,
                              self.graph_notebook_config.port,
                              self.graph_notebook_config.ssl,
                              request_generator, headers, endpoint_prefix)
        store_to_ns(args.store_to, res, local_ns)
        titles = []
        children = []

        display(tab)
        table_output = widgets.Output(layout=DEFAULT_LAYOUT)
        # Assign an empty value so we can always display to table output.
        # We will only add it as a tab if the type of query allows it.
        # Because of this, the table_output will only be displayed on the DOM if the query was of type SELECT.
        table_html = ""

        if query_type in ['SELECT', 'CONSTRUCT', 'DESCRIBE']:
            logger.debug('creating sparql network...')
            # some issues with displaying a datatable when not wrapped in an hbox and displayed last
            hbox = widgets.HBox([table_output], layout=DEFAULT_LAYOUT)
            titles.append('Table')
            children.append(hbox)

            # BUGFIX: use the parsed argparse flag rather than comparing the
            # whole line to '--expand-all', which failed whenever any other
            # argument was present on the magic line.
            expand_all = args.expand_all
            sn = SPARQLNetwork(expand_all=expand_all)
            sn.extract_prefix_declarations_from_query(cell)
            try:
                sn.add_results(res)
            except ValueError as value_error:
                # Non-graph-shaped results simply produce an empty network.
                logger.debug(value_error)

            logger.debug(f'number of nodes is {len(sn.graph.nodes)}')
            if len(sn.graph.nodes) > 0:
                f = Force(network=sn, options=self.graph_notebook_vis_options)
                titles.append('Graph')
                children.append(f)
                logger.debug('added sparql network to tabs')

            rows_and_columns = get_rows_and_columns(res)
            if rows_and_columns is not None:
                table_id = f"table-{str(uuid.uuid4())[:8]}"
                table_html = sparql_table_template.render(
                    columns=rows_and_columns['columns'],
                    rows=rows_and_columns['rows'],
                    guid=table_id)

            # Handling CONSTRUCT and DESCRIBE on their own because we want to maintain the previous result pattern
            # of showing a tsv with each line being a result binding in addition to new ones.
            if query_type == 'CONSTRUCT' or query_type == 'DESCRIBE':
                lines = []
                for b in res['results']['bindings']:
                    lines.append(
                        f'{b["subject"]["value"]}\t{b["predicate"]["value"]}\t{b["object"]["value"]}')
                raw_output = widgets.Output(layout=DEFAULT_LAYOUT)
                with raw_output:
                    html = sparql_construct_template.render(lines=lines)
                    display(HTML(html))
                children.append(raw_output)
                titles.append('Raw')

        json_output = widgets.Output(layout=DEFAULT_LAYOUT)
        with json_output:
            print(json.dumps(res, indent=2))
        children.append(json_output)
        titles.append('JSON')

        tab.children = children
        for i in range(len(titles)):
            tab.set_title(i, titles[i])
        with table_output:
            display(HTML(table_html))