Example #1
    def test_do_gremlin_status_and_cancel(self):
        query = "g.V().out().out().out().out()"
        query_res = {}
        gremlin_query_thread = threading.Thread(
            target=self.do_gremlin_query_save_results,
            args=(
                query,
                query_res,
            ))
        gremlin_query_thread.start()
        time.sleep(3)

        query_id = ''
        request_generator = create_request_generator(AuthModeEnum.DEFAULT)
        status_res = do_gremlin_status(self.host, self.port, self.ssl,
                                       self.auth_mode, request_generator,
                                       query_id, False)
        self.assertEqual(type(status_res), dict)
        self.assertTrue('acceptedQueryCount' in status_res)
        self.assertTrue('runningQueryCount' in status_res)
        self.assertTrue(status_res['runningQueryCount'] == 1)
        self.assertTrue('queries' in status_res)

        query_id = ''
        for q in status_res['queries']:
            if query in q['queryString']:
                query_id = q['queryId']

        self.assertNotEqual(query_id, '')

        cancel_res = do_gremlin_cancel(self.host, self.port, self.ssl,
                                       self.auth_mode, request_generator,
                                       query_id)
        self.assertEqual(type(cancel_res), dict)
        self.assertTrue('status' in cancel_res)
        self.assertTrue('payload' in cancel_res)
        self.assertEqual('200 OK', cancel_res['status'])

        gremlin_query_thread.join()
        self.assertFalse('result' in query_res)
        self.assertTrue('error' in query_res)
        self.assertTrue('code' in query_res['error'])
        self.assertTrue('requestId' in query_res['error'])
        self.assertTrue('detailedMessage' in query_res['error'])
        self.assertTrue('TimeLimitExceededException' in query_res['error'])
Example #2
    def test_do_sparql_status_and_cancel(self):
        query = "SELECT * WHERE { ?s ?p ?o . ?s2 ?p2 ?o2 .?s3 ?p3 ?o3 .} ORDER BY DESC(?s) LIMIT 100"
        query_res = {}
        sparql_query_thread = threading.Thread(
            target=self.do_sparql_query_save_result, args=(
                query,
                query_res,
            ))
        sparql_query_thread.start()
        time.sleep(1)

        query_id = ''
        request_generator = create_request_generator(
            AuthModeEnum.IAM, IAMAuthCredentialsProvider.ENV)
        status_res = do_sparql_status(self.host, self.port, self.ssl,
                                      request_generator, query_id)
        self.assertEqual(type(status_res), dict)
        self.assertTrue('acceptedQueryCount' in status_res)
        self.assertTrue('runningQueryCount' in status_res)
        self.assertTrue('queries' in status_res)

        time.sleep(1)

        query_id = ''
        for q in status_res['queries']:
            if query in q['queryString']:
                query_id = q['queryId']

        self.assertNotEqual(query_id, '')

        cancel_res = do_sparql_cancel(self.host, self.port, self.ssl,
                                      request_generator, query_id, False)
        self.assertEqual(type(cancel_res), dict)
        self.assertTrue('acceptedQueryCount' in cancel_res)
        self.assertTrue('runningQueryCount' in cancel_res)
        self.assertTrue('queries' in cancel_res)

        sparql_query_thread.join()
        self.assertFalse('result' in query_res)
        self.assertTrue('error' in query_res)
        self.assertTrue('code' in query_res['error'])
        self.assertTrue('requestId' in query_res['error'])
        self.assertTrue('detailedMessage' in query_res['error'])
        self.assertEqual('CancelledByUserException',
                         query_res['error']['code'])
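
Both tests above share the same lookup-and-cancel pattern: start a long-running query on a background thread, find its queryId in the status response, then cancel it by id. Below is a minimal standalone sketch of that step for SPARQL, reusing only the calls shown above; the host/port/ssl parameters are placeholders and the imports are assumed to match the test module's.

# Sketch only: assumes the same imports as the test snippets above
# (create_request_generator, do_sparql_status, do_sparql_cancel, AuthModeEnum).
def cancel_matching_sparql_query(host, port, ssl, query_substring, silent=False):
    """Find a running SPARQL query whose text contains query_substring and cancel it."""
    request_generator = create_request_generator(AuthModeEnum.DEFAULT)

    # An empty query_id asks for the status of all queries, as in the tests above.
    status_res = do_sparql_status(host, port, ssl, request_generator, '')

    query_id = ''
    for q in status_res.get('queries', []):
        if query_substring in q['queryString']:
            query_id = q['queryId']
            break

    if query_id == '':
        return None  # nothing matching is currently running

    # The trailing boolean is the "silent" flag seen in the tests above.
    return do_sparql_cancel(host, port, ssl, request_generator, query_id, silent)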
Example #3
    def load_ids(self, line):
        credentials_provider_mode = self.graph_notebook_config.iam_credentials_provider_type
        request_generator = create_request_generator(
            self.graph_notebook_config.auth_mode, credentials_provider_mode)
        res = get_loader_jobs(self.graph_notebook_config.host,
                              self.graph_notebook_config.port,
                              self.graph_notebook_config.ssl,
                              request_generator)
        ids = []
        if 'payload' in res and 'loadIds' in res['payload']:
            ids = res['payload']['loadIds']

        labels = [widgets.Label(value=label_id) for label_id in ids]

        if not labels:
            labels = [widgets.Label(value="No load IDs found.")]

        vbox = widgets.VBox(labels)
        display(vbox)
Example #4
    def load_status(self, line, local_ns: dict = None):
        parser = argparse.ArgumentParser()
        parser.add_argument('load_id',
                            default='',
                            help='loader id to check status for')
        parser.add_argument('--store-to', type=str, default='')
        args = parser.parse_args(line.split())

        credentials_provider_mode = self.graph_notebook_config.iam_credentials_provider_type
        request_generator = create_request_generator(
            self.graph_notebook_config.auth_mode, credentials_provider_mode)
        res = get_load_status(self.graph_notebook_config.host,
                              self.graph_notebook_config.port,
                              self.graph_notebook_config.ssl,
                              request_generator, args.load_id)
        print(json.dumps(res, indent=2))

        if args.store_to != '' and local_ns is not None:
            local_ns[args.store_to] = res
Example #5
    def test_do_sparql_status_and_cancel_silently(self):
        query = "SELECT * WHERE { ?s ?p ?o . ?s2 ?p2 ?o2 .?s3 ?p3 ?o3 .} ORDER BY DESC(?s) LIMIT 100"
        query_res = {}
        sparql_query_thread = threading.Thread(
            target=self.do_sparql_query_save_result, args=(
                query,
                query_res,
            ))
        sparql_query_thread.start()
        time.sleep(1)

        query_id = ''
        request_generator = create_request_generator(
            AuthModeEnum.IAM, IAMAuthCredentialsProvider.ENV)
        status_res = do_sparql_status(self.host, self.port, self.ssl,
                                      request_generator, query_id)
        self.assertEqual(type(status_res), dict)
        self.assertTrue('acceptedQueryCount' in status_res)
        self.assertTrue('runningQueryCount' in status_res)
        self.assertTrue('queries' in status_res)

        query_id = ''
        for q in status_res['queries']:
            if query in q['queryString']:
                query_id = q['queryId']

        self.assertNotEqual(query_id, '')

        cancel_res = do_sparql_cancel(self.host, self.port, self.ssl,
                                      request_generator, query_id, True)
        self.assertEqual(type(cancel_res), dict)
        self.assertTrue('acceptedQueryCount' in cancel_res)
        self.assertTrue('runningQueryCount' in cancel_res)
        self.assertTrue('queries' in cancel_res)

        sparql_query_thread.join()
        self.assertEqual(type(query_res['result']), dict)
        self.assertTrue('s3' in query_res['result']['head']['vars'])
        self.assertTrue('p3' in query_res['result']['head']['vars'])
        self.assertTrue('o3' in query_res['result']['head']['vars'])
        self.assertEqual([], query_res['result']['results']['bindings'])
Example #6
    def cancel_load(self, line, local_ns: dict = None):
        parser = argparse.ArgumentParser()
        parser.add_argument('load_id',
                            default='',
                            help='loader id of the load job to cancel')
        parser.add_argument('--store-to', type=str, default='')
        args = parser.parse_args(line.split())

        credentials_provider_mode = self.graph_notebook_config.iam_credentials_provider_type
        request_generator = create_request_generator(
            self.graph_notebook_config.auth_mode, credentials_provider_mode)
        res = cancel_load(self.graph_notebook_config.host,
                          self.graph_notebook_config.port,
                          self.graph_notebook_config.ssl, request_generator,
                          args.load_id)
        if res:
            print('Cancelled successfully.')
        else:
            print('Something went wrong cancelling bulk load job.')

        if args.store_to != '' and local_ns is not None:
            local_ns[args.store_to] = res
Example #7
    def load_ids(self, line, local_ns: dict = None):
        parser = argparse.ArgumentParser()
        parser.add_argument('--store-to', type=str, default='')
        args = parser.parse_args(line.split())

        credentials_provider_mode = self.graph_notebook_config.iam_credentials_provider_type
        request_generator = create_request_generator(self.graph_notebook_config.auth_mode, credentials_provider_mode)
        res = get_loader_jobs(self.graph_notebook_config.host, self.graph_notebook_config.port,
                              self.graph_notebook_config.ssl, request_generator)
        ids = []
        if 'payload' in res and 'loadIds' in res['payload']:
            ids = res['payload']['loadIds']

        labels = [widgets.Label(value=label_id) for label_id in ids]

        if not labels:
            labels = [widgets.Label(value="No load IDs found.")]

        vbox = widgets.VBox(labels)
        display(vbox)

        if args.store_to != '' and local_ns is not None:
            local_ns[args.store_to] = res
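
Examples #3, #4 and #7 wrap the same two loader calls in ipywidgets output. Outside a notebook the same information can be printed directly; a rough sketch follows, where config stands in for self.graph_notebook_config and get_loader_jobs, get_load_status and create_request_generator are assumed to be the helpers imported by the snippets above.

import json

# Sketch only: config stands in for self.graph_notebook_config; the loader
# helpers are the ones used in the snippets above.
def print_load_jobs(config):
    request_generator = create_request_generator(config.auth_mode,
                                                 config.iam_credentials_provider_type)
    res = get_loader_jobs(config.host, config.port, config.ssl, request_generator)
    ids = res.get('payload', {}).get('loadIds', [])
    if not ids:
        print('No load IDs found.')
        return
    for load_id in ids:
        status = get_load_status(config.host, config.port, config.ssl,
                                 request_generator, load_id)
        print(f'{load_id}:')
        print(json.dumps(status, indent=2))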
Example #8
 def test_create_request_generator_sparql(self):
     mode = AuthModeEnum.DEFAULT
     command = 'sparql'
     rpg = create_request_generator(mode, command=command)
     self.assertEqual(SPARQLRequestGenerator, type(rpg))
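
The factory's return type depends on the auth mode, the credentials provider and the optional command argument. A compact sketch of the three variants exercised by the tests in this listing (Examples #8, #20 and #22); imports are assumed.

# Sketch: the three factory variants exercised in this listing.
default_gen = create_request_generator(AuthModeEnum.DEFAULT)         # DefaultRequestGenerator (Example #20)
sparql_gen = create_request_generator(AuthModeEnum.DEFAULT,
                                      command='sparql')               # SPARQLRequestGenerator (Example #8)
iam_gen = create_request_generator(AuthModeEnum.IAM,
                                   IAMAuthCredentialsProvider.ENV)    # IamRequestGenerator (Example #22)
assert type(iam_gen.credentials_provider).__name__ == 'EnvCredentialsProvider'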
Example #9
    def load(self, line):
        # since this can be a long-running task, freeze these variables in case
        # a user alters them in another command.
        host = self.graph_notebook_config.host
        port = self.graph_notebook_config.port
        ssl = self.graph_notebook_config.ssl

        credentials_provider_mode = self.graph_notebook_config.iam_credentials_provider_type
        request_generator = create_request_generator(
            self.graph_notebook_config.auth_mode, credentials_provider_mode)
        load_role_arn = self.graph_notebook_config.load_from_s3_arn
        region = self.graph_notebook_config.aws_region

        button = widgets.Button(description="Submit")
        output = widgets.Output()
        source = widgets.Text(
            value='s3://',
            placeholder='Type something',
            description='Source:',
            disabled=False,
        )

        arn = widgets.Text(value=load_role_arn,
                           placeholder='Type something',
                           description='Load ARN:',
                           disabled=False)

        source_format = widgets.Dropdown(
            options=['', 'csv', 'ntriples', 'nquads', 'rdfxml', 'turtle'],
            value=
            '',  # blank so the user has to choose a format instead of risking the default one being incorrect.
            description='Format: ',
            disabled=False)

        region_box = widgets.Text(value=region,
                                  placeholder='us-east-1',
                                  description='AWS Region:',
                                  disabled=False)

        fail_on_error = widgets.Dropdown(options=['TRUE', 'FALSE'],
                                         value='TRUE',
                                         description='Fail on Failure: ',
                                         disabled=False)

        parallelism = widgets.Dropdown(
            options=['LOW', 'MEDIUM', 'HIGH', 'OVERSUBSCRIBE'],
            value='HIGH',
            description='Parallelism :',
            disabled=False)

        update_single_cardinality = widgets.Dropdown(
            options=['TRUE', 'FALSE'],
            value='FALSE',
            description='Update Single Cardinality:',
            disabled=False,
        )

        source_hbox = widgets.HBox([source])
        arn_hbox = widgets.HBox([arn])
        source_format_hbox = widgets.HBox([source_format])

        display(source_hbox, source_format_hbox, region_box, arn_hbox,
                fail_on_error, parallelism, update_single_cardinality, button,
                output)

        def on_button_clicked(b):
            source_hbox.children = (source, )
            arn_hbox.children = (arn, )
            source_format_hbox.children = (source_format, )

            validated = True
            validation_label_style = DescriptionStyle(color='red')
            if not (source.value.startswith('s3://') and len(source.value) > 7
                    ) and not source.value.startswith('/'):
                validated = False
                source_validation_label = widgets.HTML(
                    '<p style="color:red;">Source must be an s3 bucket or file path</p>'
                )
                source_validation_label.style = validation_label_style
                source_hbox.children += (source_validation_label, )

            if source_format.value == '':
                validated = False
                source_format_validation_label = widgets.HTML(
                    '<p style="color:red;">Format cannot be blank.</p>')
                source_format_hbox.children += (
                    source_format_validation_label, )

            if not arn.value.startswith('arn:aws') and source.value.startswith(
                    "s3://"
            ):  # only do this validation if we are using an s3 bucket.
                validated = False
                arn_validation_label = widgets.HTML(
                    '<p style="color:red;">Load ARN must start with "arn:aws"</p>'
                )
                arn_hbox.children += (arn_validation_label, )

            if not validated:
                return

            source_exp = os.path.expandvars(
                source.value
            )  # replace any env variables in source.value with their values, can use $foo or ${foo}. Particularly useful for ${AWS_REGION}
            logger.info(f'using source_exp: {source_exp}')
            try:
                load_result = do_load(host, port, source_format.value, ssl,
                                      str(source_exp), region_box.value,
                                      arn.value, fail_on_error.value,
                                      parallelism.value,
                                      update_single_cardinality.value,
                                      request_generator)

                source_hbox.close()
                source_format_hbox.close()
                region_box.close()
                arn_hbox.close()
                fail_on_error.close()
                parallelism.close()
                update_single_cardinality.close()
                button.close()
                output.close()

                if 'status' not in load_result or load_result[
                        'status'] != '200 OK':
                    with output:
                        print('Something went wrong.')
                        print(load_result)
                        logger.error(load_result)
                    return

                load_id_label = widgets.Label(
                    f'Load ID: {load_result["payload"]["loadId"]}')
                poll_interval = 5
                interval_output = widgets.Output()
                job_status_output = widgets.Output()

                load_id_hbox = widgets.HBox([load_id_label])
                status_hbox = widgets.HBox([interval_output])
                vbox = widgets.VBox(
                    [load_id_hbox, status_hbox, job_status_output])
                display(vbox)

                last_poll_time = time.time()
                while True:
                    time_elapsed = int(time.time() - last_poll_time)
                    time_remaining = poll_interval - time_elapsed
                    interval_output.clear_output()
                    if time_elapsed > poll_interval:
                        with interval_output:
                            print('checking status...')
                        job_status_output.clear_output()
                        with job_status_output:
                            display_html(HTML(loading_wheel_html))
                        try:
                            interval_check_response = get_load_status(
                                host, port, ssl, request_generator,
                                load_result['payload']['loadId'])
                        except Exception as e:
                            logger.error(e)
                            with job_status_output:
                                print(
                                    'Something went wrong updating job status. Ending.'
                                )
                                return
                        job_status_output.clear_output()
                        with job_status_output:
                            print(
                                f'Overall Status: {interval_check_response["payload"]["overallStatus"]["status"]}'
                            )
                            if interval_check_response["payload"][
                                    "overallStatus"][
                                        "status"] == 'LOAD_COMPLETED':
                                interval_output.close()
                                print('Done.')
                                return
                        last_poll_time = time.time()
                    else:
                        with interval_output:
                            print(
                                f'checking status in {time_remaining} seconds')
                    time.sleep(1)
            except HTTPError as httpEx:
                output.clear_output()
                with output:
                    print(httpEx.response.content.decode('utf-8'))

        button.on_click(on_button_clicked)
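
The status loop inside on_button_clicked only re-polls once poll_interval seconds have elapsed, printing a countdown in between. Stripped of the widget plumbing, the timing logic reduces to roughly the sketch below, where check_once is a hypothetical callback standing in for the get_load_status call.

import time

# Sketch of the polling pattern above; check_once is a hypothetical callback
# that performs one status check and returns True when the job is done.
def poll_until_done(check_once, poll_interval=5):
    last_poll_time = time.time()
    while True:
        time_elapsed = int(time.time() - last_poll_time)
        time_remaining = poll_interval - time_elapsed
        if time_elapsed > poll_interval:
            if check_once():
                return
            last_poll_time = time.time()
        else:
            print(f'checking status in {time_remaining} seconds')
        time.sleep(1)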
Example #10
    def db_reset(self, line):
        host = self.graph_notebook_config.host
        port = self.graph_notebook_config.port
        ssl = self.graph_notebook_config.ssl

        logger.info(f'calling system endpoint {host}')
        parser = argparse.ArgumentParser()
        parser.add_argument('-g',
                            '--generate-token',
                            action='store_true',
                            help='generate token for database reset')
        parser.add_argument('-t',
                            '--token',
                            nargs=1,
                            default='',
                            help='perform database reset with given token')
        parser.add_argument('-y',
                            '--yes',
                            action='store_true',
                            help='skip the prompt and perform database reset')
        args = parser.parse_args(line.split())
        generate_token = args.generate_token
        skip_prompt = args.yes
        request_generator = create_request_generator(
            self.graph_notebook_config.auth_mode,
            self.graph_notebook_config.iam_credentials_provider_type)
        logger.info(
            f'used credentials_provider_mode={self.graph_notebook_config.iam_credentials_provider_type.name} and auth_mode={self.graph_notebook_config.auth_mode.name} to make system request'
        )
        if generate_token is False and args.token == '':
            if skip_prompt:
                res = initiate_database_reset(host, port, ssl,
                                              request_generator)
                token = res['payload']['token']
                res = perform_database_reset(token, host, port, ssl,
                                             request_generator)
                logger.info(f'got the response {res}')
                return res

            output = widgets.Output()
            source = 'Are you sure you want to delete all the data in your cluster?'
            label = widgets.Label(source)
            text_hbox = widgets.HBox([label])
            check_box = widgets.Checkbox(
                value=False,
                disabled=False,
                indent=False,
                description=
                'I acknowledge that upon deletion the cluster data will no longer be available.',
                layout=widgets.Layout(width='600px', margin='5px 5px 5px 5px'))
            button_delete = widgets.Button(description="Delete")
            button_cancel = widgets.Button(description="Cancel")
            button_hbox = widgets.HBox([button_delete, button_cancel])

            display(text_hbox, check_box, button_hbox, output)

            def on_button_delete_clicked(b):
                result = initiate_database_reset(host, port, ssl,
                                                 request_generator)

                text_hbox.close()
                check_box.close()
                button_delete.close()
                button_cancel.close()
                button_hbox.close()

                if not check_box.value:
                    with output:
                        print('Checkbox is not checked.')
                    return
                token = result['payload']['token']
                if token == "":
                    with output:
                        print('Failed to get token.')
                        print(result)
                    return

                result = perform_database_reset(token, host, port, ssl,
                                                request_generator)

                if 'status' not in result or result['status'] != '200 OK':
                    with output:
                        print(
                            'Database reset failed, please try the operation again or reboot the cluster.'
                        )
                        print(result)
                        logger.error(result)
                    return

                retry = 10
                poll_interval = 5
                interval_output = widgets.Output()
                job_status_output = widgets.Output()
                status_hbox = widgets.HBox([interval_output])
                vbox = widgets.VBox([status_hbox, job_status_output])
                display(vbox)

                last_poll_time = time.time()
                while retry > 0:
                    time_elapsed = int(time.time() - last_poll_time)
                    time_remaining = poll_interval - time_elapsed
                    interval_output.clear_output()
                    if time_elapsed > poll_interval:
                        with interval_output:
                            print('checking status...')
                        job_status_output.clear_output()
                        with job_status_output:
                            display_html(HTML(loading_wheel_html))
                        try:
                            retry -= 1
                            interval_check_response = get_status(
                                host, port, ssl, request_generator)
                        except Exception as e:
                            # Exception is expected when database is resetting, continue waiting
                            with job_status_output:
                                last_poll_time = time.time()
                                time.sleep(1)
                                continue
                        job_status_output.clear_output()
                        with job_status_output:
                            if interval_check_response["status"] == 'healthy':
                                interval_output.close()
                                print('Database has been reset.')
                                return
                        last_poll_time = time.time()
                    else:
                        with interval_output:
                            print(
                                f'checking status in {time_remaining} seconds')
                    time.sleep(1)
                with output:
                    print(result)
                    if interval_check_response["status"] != 'healthy':
                        print(
                            "Could not retrieve the status of the reset operation within the allotted time. "
                            "If the database is not healthy after 1 min, please try the operation again or "
                            "reboot the cluster.")

            def on_button_cancel_clicked(b):
                text_hbox.close()
                check_box.close()
                button_delete.close()
                button_cancel.close()
                button_hbox.close()
                with output:
                    print('Database reset operation has been canceled.')

            button_delete.on_click(on_button_delete_clicked)
            button_cancel.on_click(on_button_cancel_clicked)
            return
        elif generate_token:
            res = initiate_database_reset(host, port, ssl, request_generator)
        else:
            # args.token is an array of a single string, e.g., args.token=['ade-23-c23'], use index 0 to take the string
            res = perform_database_reset(args.token[0], host, port, ssl,
                                         request_generator)

        logger.info(f'got the response {res}')
        return res
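
Behind the widgets, the reset is a two-step call: initiate_database_reset to obtain a token, then perform_database_reset with that token, which is also what the --yes path above does. A bare, non-interactive sketch, with imports and the request_generator assumed to match the magic's setup:

# Sketch of the non-interactive path taken when --yes is passed above.
def reset_database(host, port, ssl, request_generator):
    res = initiate_database_reset(host, port, ssl, request_generator)
    token = res['payload']['token']
    if token == '':
        raise RuntimeError(f'failed to obtain a reset token: {res}')
    result = perform_database_reset(token, host, port, ssl, request_generator)
    if result.get('status') != '200 OK':
        raise RuntimeError(f'database reset failed: {result}')
    return result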
Example #11
    def gremlin(self, line, cell):
        parser = argparse.ArgumentParser()
        parser.add_argument(
            'query_mode',
            nargs='?',
            default='query',
            help='query mode (default=query) [query|explain|profile]')
        parser.add_argument('-p',
                            '--path-pattern',
                            default='',
                            help='path pattern')
        args = parser.parse_args(line.split())
        mode = str_to_query_mode(args.query_mode)

        tab = widgets.Tab()
        if mode == QueryMode.EXPLAIN:
            request_generator = create_request_generator(
                self.graph_notebook_config.auth_mode,
                self.graph_notebook_config.iam_credentials_provider_type)
            raw_html = gremlin_explain(cell, self.graph_notebook_config.host,
                                       self.graph_notebook_config.port,
                                       self.graph_notebook_config.ssl,
                                       request_generator)
            explain_output = widgets.Output(layout=DEFAULT_LAYOUT)
            with explain_output:
                display(HTML(raw_html))
            tab.children = [explain_output]
            tab.set_title(0, 'Explain')
            display(tab)
        elif mode == QueryMode.PROFILE:
            request_generator = create_request_generator(
                self.graph_notebook_config.auth_mode,
                self.graph_notebook_config.iam_credentials_provider_type)
            raw_html = gremlin_profile(cell, self.graph_notebook_config.host,
                                       self.graph_notebook_config.port,
                                       self.graph_notebook_config.ssl,
                                       request_generator)
            profile_output = widgets.Output(layout=DEFAULT_LAYOUT)
            with profile_output:
                display(HTML(raw_html))
            tab.children = [profile_output]
            tab.set_title(0, 'Profile')
            display(tab)
        else:
            client_provider = create_client_provider(
                self.graph_notebook_config.auth_mode,
                self.graph_notebook_config.iam_credentials_provider_type)
            res = do_gremlin_query(cell, self.graph_notebook_config.host,
                                   self.graph_notebook_config.port,
                                   self.graph_notebook_config.ssl,
                                   client_provider)
            children = []
            titles = []

            table_output = widgets.Output(layout=DEFAULT_LAYOUT)
            titles.append('Console')
            children.append(table_output)

            try:
                gn = GremlinNetwork()
                if args.path_pattern == '':
                    gn.add_results(res)
                else:
                    pattern = parse_pattern_list_str(args.path_pattern)
                    gn.add_results_with_pattern(res, pattern)
                logger.debug(f'number of nodes is {len(gn.graph.nodes)}')
                if len(gn.graph.nodes) > 0:
                    f = Force(network=gn,
                              options=self.graph_notebook_vis_options)
                    titles.append('Graph')
                    children.append(f)
                    logger.debug('added gremlin network to tabs')
            except ValueError as value_error:
                logger.debug(
                    f'unable to create gremlin network from result. Skipping from result set: {value_error}'
                )

            tab.children = children
            for i in range(len(titles)):
                tab.set_title(i, titles[i])
            display(tab)

            table_id = f"table-{str(uuid.uuid4()).replace('-', '')[:8]}"
            table_html = gremlin_table_template.render(guid=table_id,
                                                       results=res)
            with table_output:
                display(HTML(table_html))
Example #12
    def sparql(self, line='', cell=''):
        request_generator = create_request_generator(
            self.graph_notebook_config.auth_mode,
            self.graph_notebook_config.iam_credentials_provider_type,
            command='sparql')

        if line != '':
            mode = str_to_query_mode(line)
        else:
            mode = self.mode
        tab = widgets.Tab()
        logger.debug(f'using mode={mode}')
        if mode == QueryMode.EXPLAIN:
            html_raw = sparql_explain(cell, self.graph_notebook_config.host,
                                      self.graph_notebook_config.port,
                                      self.graph_notebook_config.ssl,
                                      request_generator)
            explain_output = widgets.Output(layout=DEFAULT_LAYOUT)
            with explain_output:
                display(HTML(html_raw))

            tab.children = [explain_output]
            tab.set_title(0, 'Explain')
            display(tab)
        else:
            query_type = get_query_type(cell)
            headers = {} if query_type not in [
                'SELECT', 'CONSTRUCT', 'DESCRIBE'
            ] else {
                'Accept': 'application/sparql-results+json'
            }
            res = do_sparql_query(cell, self.graph_notebook_config.host,
                                  self.graph_notebook_config.port,
                                  self.graph_notebook_config.ssl,
                                  request_generator, headers)
            titles = []
            children = []

            display(tab)
            table_output = widgets.Output(layout=DEFAULT_LAYOUT)
            # Assign an empty value so we can always display to table output.
            # We will only add it as a tab if the type of query allows it.
            # Because of this, the table_output will only be displayed on the DOM if the query was of type SELECT.
            table_html = ""
            query_type = get_query_type(cell)
            if query_type in ['SELECT', 'CONSTRUCT', 'DESCRIBE']:
                logger.debug('creating sparql network...')

                # some issues with displaying a datatable when not wrapped in an hbox and displayed last
                hbox = widgets.HBox([table_output], layout=DEFAULT_LAYOUT)
                titles.append('Table')
                children.append(hbox)

                expand_all = line == '--expand-all'
                sn = SPARQLNetwork(expand_all=expand_all)
                sn.extract_prefix_declarations_from_query(cell)
                try:
                    sn.add_results(res)
                except ValueError as value_error:
                    logger.debug(value_error)

                logger.debug(f'number of nodes is {len(sn.graph.nodes)}')
                if len(sn.graph.nodes) > 0:
                    f = Force(network=sn,
                              options=self.graph_notebook_vis_options)
                    titles.append('Graph')
                    children.append(f)
                    logger.debug('added sparql network to tabs')

                rows_and_columns = get_rows_and_columns(res)
                if rows_and_columns is not None:
                    table_id = f"table-{str(uuid.uuid4())[:8]}"
                    table_html = sparql_table_template.render(
                        columns=rows_and_columns['columns'],
                        rows=rows_and_columns['rows'],
                        guid=table_id)

                # Handling CONSTRUCT and DESCRIBE on their own because we want to maintain the previous result pattern
                # of showing a tsv with each line being a result binding in addition to new ones.
                if query_type == 'CONSTRUCT' or query_type == 'DESCRIBE':
                    lines = []
                    for b in res['results']['bindings']:
                        lines.append(
                            f'{b["subject"]["value"]}\t{b["predicate"]["value"]}\t{b["object"]["value"]}'
                        )
                    raw_output = widgets.Output(layout=DEFAULT_LAYOUT)
                    with raw_output:
                        html = sparql_construct_template.render(lines=lines)
                        display(HTML(html))
                    children.append(raw_output)
                    titles.append('Raw')

            json_output = widgets.Output(layout=DEFAULT_LAYOUT)
            with json_output:
                print(json.dumps(res, indent=2))
            children.append(json_output)
            titles.append('JSON')

            tab.children = children
            for i in range(len(titles)):
                tab.set_title(i, titles[i])

            with table_output:
                display(HTML(table_html))
Example #13
    def gremlin(self, line, cell, local_ns: dict = None):
        parser = argparse.ArgumentParser()
        parser.add_argument('query_mode', nargs='?', default='query',
                            help='query mode (default=query) [query|explain|profile]')
        parser.add_argument('-p', '--path-pattern', default='', help='path pattern')
        parser.add_argument('-g', '--group-by', default='T.label', help='Property used to group nodes (e.g. code, T.region) default is T.label')
        parser.add_argument('--store-to', type=str, default='', help='store query result to this variable')
        parser.add_argument('--ignore-groups', action='store_true', help="Ignore all grouping options")
        args = parser.parse_args(line.split())
        mode = str_to_query_mode(args.query_mode)
        logger.debug(f'Arguments {args}')

        tab = widgets.Tab()
        if mode == QueryMode.EXPLAIN:
            request_generator = create_request_generator(self.graph_notebook_config.auth_mode,
                                                         self.graph_notebook_config.iam_credentials_provider_type)

            query_res = do_gremlin_explain(cell, self.graph_notebook_config.host, self.graph_notebook_config.port,
                                           self.graph_notebook_config.ssl, request_generator)
            if 'explain' in query_res:
                html = pre_container_template.render(content=query_res['explain'])
            else:
                html = pre_container_template.render(content='No explain found')

            explain_output = widgets.Output(layout=DEFAULT_LAYOUT)
            with explain_output:
                display(HTML(html))
            tab.children = [explain_output]
            tab.set_title(0, 'Explain')
            display(tab)
        elif mode == QueryMode.PROFILE:
            request_generator = create_request_generator(self.graph_notebook_config.auth_mode,
                                                         self.graph_notebook_config.iam_credentials_provider_type)

            query_res = do_gremlin_profile(cell, self.graph_notebook_config.host, self.graph_notebook_config.port,
                                           self.graph_notebook_config.ssl, request_generator)
            if 'profile' in query_res:
                html = pre_container_template.render(content=query_res['profile'])
            else:
                html = pre_container_template.render(content='No profile found')
            profile_output = widgets.Output(layout=DEFAULT_LAYOUT)
            with profile_output:
                display(HTML(html))
            tab.children = [profile_output]
            tab.set_title(0, 'Profile')
            display(tab)
        else:
            client_provider = create_client_provider(self.graph_notebook_config.auth_mode,
                                                     self.graph_notebook_config.iam_credentials_provider_type)
            query_res = do_gremlin_query(cell, self.graph_notebook_config.host, self.graph_notebook_config.port,
                                         self.graph_notebook_config.ssl, client_provider)
            children = []
            titles = []

            table_output = widgets.Output(layout=DEFAULT_LAYOUT)
            titles.append('Console')
            children.append(table_output)

            try:
                logger.debug(f'groupby: {args.group_by}')
                if args.ignore_groups:
                    gn = GremlinNetwork()
                else:
                    gn = GremlinNetwork(group_by_property=args.group_by)

                if args.path_pattern == '':
                    gn.add_results(query_res)
                else:
                    pattern = parse_pattern_list_str(args.path_pattern)
                    gn.add_results_with_pattern(query_res, pattern)
                logger.debug(f'number of nodes is {len(gn.graph.nodes)}')
                if len(gn.graph.nodes) > 0:
                    f = Force(network=gn, options=self.graph_notebook_vis_options)
                    titles.append('Graph')
                    children.append(f)
                    logger.debug('added gremlin network to tabs')
            except ValueError as value_error:
                logger.debug(f'unable to create gremlin network from result. Skipping from result set: {value_error}')

            tab.children = children
            for i in range(len(titles)):
                tab.set_title(i, titles[i])
            display(tab)

            table_id = f"table-{str(uuid.uuid4()).replace('-', '')[:8]}"
            table_html = gremlin_table_template.render(guid=table_id, results=query_res)
            with table_output:
                display(HTML(table_html))

        store_to_ns(args.store_to, query_res, local_ns)
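
Both gremlin magics (Examples #11 and #13) dispatch on the same three modes before any rendering happens. A sketch of that dispatch without the widget output, where config stands in for self.graph_notebook_config and the helpers are those used above:

# Sketch of the query/explain/profile dispatch used above, minus the widgets.
def run_gremlin(query, mode_str, config):
    mode = str_to_query_mode(mode_str)
    if mode in (QueryMode.EXPLAIN, QueryMode.PROFILE):
        request_generator = create_request_generator(config.auth_mode,
                                                     config.iam_credentials_provider_type)
        if mode == QueryMode.EXPLAIN:
            return do_gremlin_explain(query, config.host, config.port, config.ssl,
                                      request_generator)
        return do_gremlin_profile(query, config.host, config.port, config.ssl,
                                  request_generator)
    client_provider = create_client_provider(config.auth_mode,
                                             config.iam_credentials_provider_type)
    return do_gremlin_query(query, config.host, config.port, config.ssl, client_provider)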
Example #14
 def test_do_status_with_iam_credentials(self):
     request_generator = create_request_generator(
         AuthModeEnum.IAM, IAMAuthCredentialsProvider.ENV)
     status = get_status(self.host, self.port, self.ssl, request_generator)
     self.assertEqual(status['status'], 'healthy')
Example #15
 def test_do_sparql_cancel_non_str_query_id(self):
     with self.assertRaises(ValueError):
         query_id = 42
         request_generator = create_request_generator(AuthModeEnum.DEFAULT)
         do_sparql_cancel(self.host, self.port, self.ssl, request_generator, query_id, False)
Example #16
        def on_button_clicked(b=None):
            submit_button.close()
            language_dropdown.disabled = True
            data_set_drop_down.disabled = True
            language = language_dropdown.value.lower()
            data_set = data_set_drop_down.value.lower()
            with output:
                print(f'Loading data set {data_set} with language {language}')
            queries = get_queries(language, data_set)
            if len(queries) < 1:
                with output:
                    print('Did not find any queries for the given dataset')
                return

            load_index = 1  # start at 1 to have a non-empty progress bar
            progress = widgets.IntProgress(
                value=load_index,
                min=0,
                max=len(queries) + 1,  # len + 1 so we can start at index 1
                orientation='horizontal',
                bar_style='info',
                description='Loading:')

            with progress_output:
                display(progress)

            for q in queries:
                with output:
                    print(f'{progress.value}/{len(queries)}:\t{q["name"]}')
                # Just like with the load command, seed is long-running
                # as such, we want to obtain the values of host, port, etc. in case they
                # change during execution.
                host = self.graph_notebook_config.host
                port = self.graph_notebook_config.port
                auth_mode = self.graph_notebook_config.auth_mode
                ssl = self.graph_notebook_config.ssl

                if language == 'gremlin':
                    client_provider = create_client_provider(
                        auth_mode, self.graph_notebook_config.
                        iam_credentials_provider_type)
                    # IMPORTANT: We treat each line as its own query!
                    for line in q['content'].splitlines():
                        try:
                            do_gremlin_query(line, host, port, ssl,
                                             client_provider)
                        except GremlinServerError as gremlinEx:
                            try:
                                error = json.loads(
                                    gremlinEx.args[0]
                                    [5:])  # remove the leading error code.
                                content = json.dumps(error, indent=2)
                            except Exception:
                                content = {'error': gremlinEx}

                            with output:
                                print(content)
                            progress.close()
                            return
                        except Exception as e:
                            content = {'error': e}
                            with output:
                                print(content)
                            progress.close()
                            return
                else:
                    request_generator = create_request_generator(
                        auth_mode, self.graph_notebook_config.
                        iam_credentials_provider_type)
                    try:
                        do_sparql_query(q['content'], host, port, ssl,
                                        request_generator)
                    except HTTPError as httpEx:
                        # attempt to turn response into json
                        try:
                            error = json.loads(
                                httpEx.response.content.decode('utf-8'))
                            content = json.dumps(error, indent=2)
                        except Exception:
                            content = {'error': httpEx}
                        with output:
                            print(content)
                        progress.close()
                        return
                    except Exception as ex:
                        content = {'error': str(ex)}
                        with output:
                            print(content)

                        progress.close()
                        return

                progress.value += 1

            progress.close()
            with output:
                print('Done.')
            return
Example #17
 def test_do_gremlin_cancel_non_str_query_id(self):
     with self.assertRaises(ValueError):
         query_id = 42
         request_generator = create_request_generator(AuthModeEnum.DEFAULT)
         do_gremlin_cancel(self.host, self.port, self.ssl, self.auth_mode,
                           request_generator, query_id)
Example #18
    def load(self, line='', local_ns: dict = None):
        parser = argparse.ArgumentParser()
        parser.add_argument('-s', '--source', default='s3://')
        parser.add_argument(
            '-l',
            '--loader-arn',
            default=self.graph_notebook_config.load_from_s3_arn)
        parser.add_argument('-f',
                            '--format',
                            choices=LOADER_FORMAT_CHOICES,
                            default='')
        parser.add_argument('-p',
                            '--parallelism',
                            choices=PARALLELISM_OPTIONS,
                            default=PARALLELISM_HIGH)
        parser.add_argument('-r',
                            '--region',
                            default=self.graph_notebook_config.aws_region)
        parser.add_argument('--fail-on-failure',
                            action='store_true',
                            default=False)
        parser.add_argument('--update-single-cardinality',
                            action='store_true',
                            default=True)
        parser.add_argument('--store-to',
                            type=str,
                            default='',
                            help='store query result to this variable')
        parser.add_argument('--run', action='store_true', default=False)

        args = parser.parse_args(line.split())

        # since this can be a long-running task, freeze these variables in case
        # a user alters them in another command.
        host = self.graph_notebook_config.host
        port = self.graph_notebook_config.port
        ssl = self.graph_notebook_config.ssl

        credentials_provider_mode = self.graph_notebook_config.iam_credentials_provider_type
        request_generator = create_request_generator(
            self.graph_notebook_config.auth_mode, credentials_provider_mode)
        region = self.graph_notebook_config.aws_region

        button = widgets.Button(description="Submit")
        output = widgets.Output()
        source = widgets.Text(
            value=args.source,
            placeholder='Type something',
            description='Source:',
            disabled=False,
        )

        arn = widgets.Text(value=args.loader_arn,
                           placeholder='Type something',
                           description='Load ARN:',
                           disabled=False)

        source_format = widgets.Dropdown(options=LOADER_FORMAT_CHOICES,
                                         value=args.format,
                                         description='Format: ',
                                         disabled=False)

        region_box = widgets.Text(value=region,
                                  placeholder=args.region,
                                  description='AWS Region:',
                                  disabled=False)

        fail_on_error = widgets.Dropdown(options=['TRUE', 'FALSE'],
                                         value=str(
                                             args.fail_on_failure).upper(),
                                         description='Fail on Failure: ',
                                         disabled=False)

        parallelism = widgets.Dropdown(options=PARALLELISM_OPTIONS,
                                       value=args.parallelism,
                                       description='Parallelism :',
                                       disabled=False)

        update_single_cardinality = widgets.Dropdown(
            options=['TRUE', 'FALSE'],
            value=str(args.update_single_cardinality).upper(),
            description='Update Single Cardinality:',
            disabled=False,
        )

        source_hbox = widgets.HBox([source])
        arn_hbox = widgets.HBox([arn])
        source_format_hbox = widgets.HBox([source_format])

        display(source_hbox, source_format_hbox, region_box, arn_hbox,
                fail_on_error, parallelism, update_single_cardinality, button,
                output)

        def on_button_clicked(b):
            source_hbox.children = (source, )
            arn_hbox.children = (arn, )
            source_format_hbox.children = (source_format, )

            validated = True
            validation_label_style = DescriptionStyle(color='red')
            if not (source.value.startswith('s3://') and len(source.value) > 7
                    ) and not source.value.startswith('/'):
                validated = False
                source_validation_label = widgets.HTML(
                    '<p style="color:red;">Source must be an s3 bucket or file path</p>'
                )
                source_validation_label.style = validation_label_style
                source_hbox.children += (source_validation_label, )

            if source_format.value == '':
                validated = False
                source_format_validation_label = widgets.HTML(
                    '<p style="color:red;">Format cannot be blank.</p>')
                source_format_hbox.children += (
                    source_format_validation_label, )

            if not arn.value.startswith('arn:aws') and source.value.startswith(
                    "s3://"
            ):  # only do this validation if we are using an s3 bucket.
                validated = False
                arn_validation_label = widgets.HTML(
                    '<p style="color:red;">Load ARN must start with "arn:aws"</p>'
                )
                arn_hbox.children += (arn_validation_label, )

            if not validated:
                return

            source_exp = os.path.expandvars(
                source.value
            )  # replace any env variables in source.value with their values, can use $foo or ${foo}. Particularly useful for ${AWS_REGION}
            logger.info(f'using source_exp: {source_exp}')
            try:
                load_result = do_load(host, port, source_format.value, ssl,
                                      str(source_exp), region_box.value,
                                      arn.value, fail_on_error.value,
                                      parallelism.value,
                                      update_single_cardinality.value,
                                      request_generator)
                store_to_ns(args.store_to, load_result, local_ns)

                source_hbox.close()
                source_format_hbox.close()
                region_box.close()
                arn_hbox.close()
                fail_on_error.close()
                parallelism.close()
                update_single_cardinality.close()
                button.close()
                output.close()

                if 'status' not in load_result or load_result[
                        'status'] != '200 OK':
                    with output:
                        print('Something went wrong.')
                        print(load_result)
                        logger.error(load_result)
                    return

                load_id_label = widgets.Label(
                    f'Load ID: {load_result["payload"]["loadId"]}')
                poll_interval = 5
                interval_output = widgets.Output()
                job_status_output = widgets.Output()

                load_id_hbox = widgets.HBox([load_id_label])
                status_hbox = widgets.HBox([interval_output])
                vbox = widgets.VBox(
                    [load_id_hbox, status_hbox, job_status_output])
                display(vbox)

                last_poll_time = time.time()
                while True:
                    time_elapsed = int(time.time() - last_poll_time)
                    time_remaining = poll_interval - time_elapsed
                    interval_output.clear_output()
                    if time_elapsed > poll_interval:
                        with interval_output:
                            print('checking status...')
                        job_status_output.clear_output()
                        with job_status_output:
                            display_html(HTML(loading_wheel_html))
                        try:
                            interval_check_response = get_load_status(
                                host, port, ssl, request_generator,
                                load_result['payload']['loadId'])
                        except Exception as e:
                            logger.error(e)
                            with job_status_output:
                                print(
                                    'Something went wrong updating job status. Ending.'
                                )
                                return
                        job_status_output.clear_output()
                        with job_status_output:
                            print(
                                f'Overall Status: {interval_check_response["payload"]["overallStatus"]["status"]}'
                            )
                            if interval_check_response["payload"][
                                    "overallStatus"][
                                        "status"] in FINAL_LOAD_STATUSES:
                                interval_output.close()
                                print('Done.')
                                return
                        last_poll_time = time.time()
                    else:
                        with interval_output:
                            print(
                                f'checking status in {time_remaining} seconds')
                    time.sleep(1)
            except HTTPError as httpEx:
                output.clear_output()
                with output:
                    print(httpEx.response.content.decode('utf-8'))

        button.on_click(on_button_clicked)
        if args.run:
            on_button_clicked(None)
Example #19
 def setUp(self) -> None:
     request_generator = create_request_generator(AuthModeEnum.DEFAULT)
     res = do_sparql_status(self.host, self.port, self.ssl, request_generator)
     for q in res['queries']:
         do_sparql_cancel(self.host, self.port, self.ssl, request_generator, q['queryId'], False)
Example #20
 def test_create_request_generator_default(self):
     mode = AuthModeEnum.DEFAULT
     rpg = create_request_generator(mode)
     self.assertEqual(DefaultRequestGenerator, type(rpg))
Example #21
 def test_do_gremlin_cancel_non_str_query_id(self):
     with self.assertRaises(ValueError):
         query_id = 42
         request_generator = create_request_generator(AuthModeEnum.IAM, IAMAuthCredentialsProvider.ENV)
         do_gremlin_cancel(self.host, self.port, self.ssl, AuthModeEnum.IAM, request_generator, query_id)
Example #22
 def test_create_request_generator_iam_env(self):
     mode = AuthModeEnum.IAM
     rpg = create_request_generator(mode, IAMAuthCredentialsProvider.ENV)
     self.assertEqual(IamRequestGenerator, type(rpg))
     self.assertEqual(EnvCredentialsProvider,
                      type(rpg.credentials_provider))
Example #23
 def test_do_db_reset_initiate_with_iam_credentials(self):
     request_generator = create_request_generator(
         AuthModeEnum.IAM, IAMAuthCredentialsProvider.ENV)
     result = initiate_database_reset(self.host, self.port, self.ssl,
                                      request_generator)
     self.assertNotEqual(result['payload']['token'], '')
Example #24
    def sparql(self, line='', cell='', local_ns: dict = None):
        parser = argparse.ArgumentParser()
        parser.add_argument('query_mode', nargs='?', default='query',
                            help='query mode (default=query) [query|explain]')
        parser.add_argument('--endpoint-prefix', '-e', default='',
                            help='prefix path to sparql endpoint. For example, if "foo/bar" were specified, the endpoint called would be /foo/bar/sparql')
        parser.add_argument('--expand-all', action='store_true')

        request_generator = create_request_generator(self.graph_notebook_config.auth_mode,
                                                     self.graph_notebook_config.iam_credentials_provider_type,
                                                     command='sparql')
        parser.add_argument('--store-to', type=str, default='', help='store query result to this variable')
        args = parser.parse_args(line.split())
        mode = str_to_query_mode(args.query_mode)

        endpoint_prefix = args.endpoint_prefix if args.endpoint_prefix != '' else self.graph_notebook_config.sparql.endpoint_prefix

        tab = widgets.Tab()
        logger.debug(f'using mode={mode}')
        if mode == QueryMode.EXPLAIN:
            res = do_sparql_explain(cell, self.graph_notebook_config.host, self.graph_notebook_config.port,
                                    self.graph_notebook_config.ssl, request_generator, path_prefix=endpoint_prefix)
            store_to_ns(args.store_to, res, local_ns)
            if 'error' in res:
                html = error_template.render(error=json.dumps(res['error'], indent=2))
            else:
                html = sparql_explain_template.render(table=res)
            explain_output = widgets.Output(layout=DEFAULT_LAYOUT)
            with explain_output:
                display(HTML(html))

            tab.children = [explain_output]
            tab.set_title(0, 'Explain')
            display(tab)
        else:
            query_type = get_query_type(cell)
            headers = {} if query_type not in ['SELECT', 'CONSTRUCT', 'DESCRIBE'] else {
                'Accept': 'application/sparql-results+json'}
            res = do_sparql_query(cell, self.graph_notebook_config.host, self.graph_notebook_config.port,
                                  self.graph_notebook_config.ssl, request_generator, headers, endpoint_prefix)
            store_to_ns(args.store_to, res, local_ns)
            titles = []
            children = []

            display(tab)
            table_output = widgets.Output(layout=DEFAULT_LAYOUT)
            # Assign an empty value so we can always display to table output.
            # We will only add it as a tab if the type of query allows it.
            # Because of this, the table_output will only be displayed on the DOM if the query was of type SELECT.
            table_html = ""
            query_type = get_query_type(cell)
            if query_type in ['SELECT', 'CONSTRUCT', 'DESCRIBE']:
                logger.debug('creating sparql network...')

                # some issues with displaying a datatable when not wrapped in an hbox and displayed last
                hbox = widgets.HBox([table_output], layout=DEFAULT_LAYOUT)
                titles.append('Table')
                children.append(hbox)

                expand_all = args.expand_all
                sn = SPARQLNetwork(expand_all=expand_all)
                sn.extract_prefix_declarations_from_query(cell)
                try:
                    sn.add_results(res)
                except ValueError as value_error:
                    logger.debug(value_error)

                logger.debug(f'number of nodes is {len(sn.graph.nodes)}')
                if len(sn.graph.nodes) > 0:
                    f = Force(network=sn, options=self.graph_notebook_vis_options)
                    titles.append('Graph')
                    children.append(f)
                    logger.debug('added sparql network to tabs')

                rows_and_columns = get_rows_and_columns(res)
                if rows_and_columns is not None:
                    table_id = f"table-{str(uuid.uuid4())[:8]}"
                    table_html = sparql_table_template.render(columns=rows_and_columns['columns'],
                                                              rows=rows_and_columns['rows'], guid=table_id)

                # Handling CONSTRUCT and DESCRIBE on their own because we want to maintain the previous result pattern
                # of showing a tsv with each line being a result binding in addition to new ones.
                if query_type == 'CONSTRUCT' or query_type == 'DESCRIBE':
                    lines = []
                    for b in res['results']['bindings']:
                        lines.append(f'{b["subject"]["value"]}\t{b["predicate"]["value"]}\t{b["object"]["value"]}')
                    raw_output = widgets.Output(layout=DEFAULT_LAYOUT)
                    with raw_output:
                        html = sparql_construct_template.render(lines=lines)
                        display(HTML(html))
                    children.append(raw_output)
                    titles.append('Raw')

            json_output = widgets.Output(layout=DEFAULT_LAYOUT)
            with json_output:
                print(json.dumps(res, indent=2))
            children.append(json_output)
            titles.append('JSON')

            tab.children = children
            for i in range(len(titles)):
                tab.set_title(i, titles[i])

            with table_output:
                display(HTML(table_html))
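
Both sparql magics pick the Accept header from the query type before calling do_sparql_query; only SELECT, CONSTRUCT and DESCRIBE ask for SPARQL JSON results. That choice on its own, as a small sketch with get_query_type assumed from the snippets above:

# Sketch: Accept-header selection used by the sparql magics above.
def sparql_headers(query):
    query_type = get_query_type(query)  # e.g. 'SELECT', 'CONSTRUCT', 'DESCRIBE', ...
    if query_type in ['SELECT', 'CONSTRUCT', 'DESCRIBE']:
        return {'Accept': 'application/sparql-results+json'}
    return {}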