Пример #1
0
def test_retries(execute_mock):
    """A client configured with ``retries=N`` calls the patched execute N times, then raises.

    ``execute_mock`` is presumably a patched transport ``execute`` supplied by a
    fixture/decorator outside this view — TODO confirm.
    """
    expected_retries = 3
    # Every transport call fails, so the client must exhaust all retries.
    execute_mock.side_effect = Exception("fail")

    transport = RequestsHTTPTransport(url='http://swapi.graphene-python.org/graphql')
    client = Client(retries=expected_retries, transport=transport)

    query = gql('''
    {
      myFavoriteFilm: film(id:"RmlsbToz") {
        id
        title
        episodeId
      }
    }
    ''')

    with pytest.raises(Exception):
        client.execute(query)

    assert execute_mock.call_count == expected_retries
Пример #2
0
class AiDungeonApiClient:
    """Synchronous client for the AI Dungeon GraphQL API over a websocket.

    An instance keeps the state of one play session between calls: the account,
    the chosen scenario and character, the story pitch and the created adventure.
    Every request goes through ``_execute_query`` and every response is echoed
    with ``debug_print``.
    """

    # Content lookup shared by get_options, get_characters and
    # get_story_for_scenario: a scenario's prompt plus its selectable options.
    _CONTENT_QUERY = '''
        query ($id: String) {  user {    id    username    __typename  }  content(id: $id) {    id    userId    contentType    contentId    prompt    gameState    options {      id      title      __typename    }    playPublicId    __typename  }}
        '''

    def __init__(self):
        self.url: str = 'wss://api.aidungeon.io/subscriptions'
        # Connect without an init payload; the login methods reconnect with a token.
        self._connect()
        self.account_id: str = ''
        self.access_token: str = ''

        self.single_player_mode_id: str = 'scenario:458612'
        self.scenario_id: str = ''  # REVIEW: maybe call it setting_id ?
        self.character_name: str = ''
        self.story_pitch_template: str = ''
        self.story_pitch: str = ''
        self.adventure_id: str = ''
        self.public_id: str = ''
        self.quests: str = ''

    def _connect(self, **transport_kwargs):
        """(Re)create the websocket transport and the sync gql client bound to it.

        ``transport_kwargs`` are forwarded to ``WebsocketsTransport``. The client
        does not enable ``fetch_schema_from_transport``.
        """
        self.websocket = WebsocketsTransport(url=self.url, **transport_kwargs)
        self.gql_client = Client(transport=self.websocket)

    async def _execute_query_pseudo_async(self, query, params=None):
        """Run ``query`` in an async gql session over the current websocket transport.

        ``params`` defaults to a fresh empty dict. It is not shared between calls,
        unlike a ``{}`` default argument.
        """
        if params is None:
            params = {}
        async with Client(transport=self.websocket) as session:
            return await session.execute(gql(query), variable_values=params)

    def _execute_query(self, query, params=None):
        """Parse ``query`` with ``gql`` and execute it synchronously with ``params`` as variables."""
        return self.gql_client.execute(gql(query), variable_values=params)

    def update_session_access_token(self, access_token):
        """Reconnect, sending ``{'token': access_token}`` as the websocket init payload."""
        self._connect(init_payload={'token': access_token})

    def _store_login(self, login):
        """Record ``id``/``accessToken`` from a login payload and reconnect with that token."""
        self.account_id = login['id']
        self.access_token = login['accessToken']
        self.update_session_access_token(self.access_token)

    def user_login(self, email, password):
        """Log in with email/password and switch the session to the returned token."""
        debug_print("user login")
        result = self._execute_query('''
        mutation ($email: String, $password: String, $anonymousId: String) {  login(email: $email, password: $password, anonymousId: $anonymousId) {    id    accessToken    __typename  }}
        ''',
                                     {
                                         "email": email,
                                         "password": password
                                     })
        debug_print(result)
        self._store_login(result['login'])

    def anonymous_login(self):
        """Create an anonymous account and switch the session to its token."""
        debug_print("anonymous login")
        result = self._execute_query('''
        mutation {  createAnonymousAccount {    id    accessToken    __typename  }}
        ''')
        debug_print(result)
        self._store_login(result['createAnonymousAccount'])

    def perform_init_handshake(self):
        """Send the requests the web client issues after login (device token, start event)."""
        debug_print("add device token")
        result = self._execute_query('''
        mutation ($token: String, $platform: String) {  addDeviceToken(token: $token, platform: $platform)}
        ''',
                                     {'token': 'web',
                                      'platform': 'web'})
        debug_print(result)

        debug_print("send event start premium")
        result = self._execute_query('''
        mutation ($input: EventInput) {  sendEvent(input: $input)}
        ''',
                                     {
                                         "input": {
                                             "eventName": "start_premium_v5",
                                             # alternative value seen: "show"
                                             "variation": "dont",
                                             "platform": "web"
                                         }
                                     })
        debug_print(result)

    @staticmethod
    def normalize_options(raw_settings_list):
        """Map 1-based menu keys ('1', '2', ...) to ``[option_id, option_title]`` pairs."""
        return {
            str(position): [option['id'], option['title']]
            for position, option in enumerate(raw_settings_list, start=1)
        }

    def _query_content(self, content_id):
        """Fetch the content record for ``content_id`` and return ``result['content']``."""
        result = self._execute_query(self._CONTENT_QUERY, {"id": content_id})
        debug_print(result)
        return result['content']

    def get_options(self, scenario_id):
        """Return ``[prompt, options]`` for a scenario.

        ``options`` is the ``normalize_options`` dict, or ``None`` when the
        scenario has no options.
        """
        debug_print("query options (variant #1)")
        content = self._query_content(scenario_id)
        prompt = content['prompt']
        options = self.normalize_options(content['options']) if content['options'] else None
        return [prompt, options]

    def get_settings_single_player(self):
        """Return ``get_options`` for the single-player mode scenario."""
        return self.get_options(self.single_player_mode_id)

    def get_characters(self):
        """Return ``[prompt, characters]`` for ``self.scenario_id``.

        ``characters`` is the ``normalize_options`` dict, or ``{}`` when the
        scenario has no options (``null`` no longer crashes ``normalize_options``).
        """
        debug_print("query settings singleplayer (variant #1)")
        content = self._query_content(self.scenario_id)
        prompt = content['prompt']
        characters = self.normalize_options(content['options']) if content['options'] else {}
        return [prompt, characters]

    def get_story_for_scenario(self):
        """Store the prompt of ``self.scenario_id`` in ``self.story_pitch_template``."""
        debug_print("query get story for scenario")
        self.story_pitch_template = self._query_content(self.scenario_id)['prompt']

    @staticmethod
    def initial_story_from_history_list(history_list):
        """Concatenate the texts of the leading 'story'/'continue' entries of a history list.

        Stops at the first entry of any other type.
        """
        parts = []
        for entry in history_list:
            if entry['type'] not in ('story', 'continue'):
                break
            parts.append(entry['text'])
        return ''.join(parts)

    def set_story_pitch(self):
        """Fill ``${character.name}`` in the pitch template with ``self.character_name``."""
        self.story_pitch = self.story_pitch_template.replace('${character.name}', self.character_name)

    def init_custom_story_pitch(self, user_input):
        """Send ``user_input`` as the opening 'story' action; store the resulting text as the pitch."""
        debug_print("send custom settings story pitch")
        result = self._execute_query('''
        mutation ($input: ContentActionInput) {  sendAction(input: $input) {    id    actionLoading    memory    died    gameState    newQuests {      id      text      completed      active      __typename    }    actions {      id      text      __typename    }    __typename  }}
        ''',
                                     {
                                         "input": {
                                             "type": "story",
                                             "text": user_input,
                                             "id": self.adventure_id}})
        debug_print(result)
        self.story_pitch = ''.join(action['text'] for action in result['sendAction']['actions'])

    def _create_adventure(self, scenario_id):
        """Create an adventure from ``scenario_id`` using ``self.story_pitch`` as its prompt.

        Stores the new adventure id. When the response carries a history list,
        the pitch is replaced by the adventure's initial story text.
        """
        debug_print("create adventure")
        result = self._execute_query('''
        mutation ($id: String, $prompt: String) {  createAdventureFromScenarioId(id: $id, prompt: $prompt) {    id    contentType    contentId    title    description    musicTheme    tags    nsfw    published    createdAt    updatedAt    deletedAt    publicId    historyList    __typename  }}
        ''',
                                     {
                                         "id": scenario_id,
                                         "prompt": self.story_pitch
                                     })
        debug_print(result)
        adventure = result['createAdventureFromScenarioId']
        self.adventure_id = adventure['id']
        # The field is requested, so it may come back as null rather than missing
        # (e.g. for a custom scenario without a pitch); only use a real list.
        history_list = adventure.get('historyList')
        if history_list is not None:
            self.story_pitch = self.initial_story_from_history_list(history_list)

    def init_story(self):
        """Create the adventure for ``self.scenario_id`` and store its quests and public id."""
        self._create_adventure(self.scenario_id)

        debug_print("get created adventure ids")
        result = self._execute_query('''
        query ($id: String, $playPublicId: String) {  content(id: $id, playPublicId: $playPublicId) {    id    historyList    quests    playPublicId    userId    __typename  }}
        ''',
                                     {
                                         "id": self.adventure_id,
                                     })
        debug_print(result)
        self.quests = result['content']['quests']
        self.public_id = result['content']['playPublicId']

    def perform_remember_action(self, user_input):
        """Add ``user_input`` to the adventure's memory via ``updateMemory``."""
        debug_print("remember something")
        result = self._execute_query('''
        mutation ($input: ContentActionInput) {  updateMemory(input: $input) {    id    memory    __typename  }}
        ''',
                                     {
                                         "input":
                                         {
                                             "text": user_input,
                                             "type": "remember",
                                             "id": self.adventure_id
                                         }
                                     })
        debug_print(result)

    def perform_regular_action(self, action, user_input):
        """Send an action of type ``action`` and return the text of the newest story action.

        Returns ``""`` when the adventure reports no actions.
        """
        debug_print("send regular action")
        result = self._execute_query('''
        mutation ($input: ContentActionInput) {  sendAction(input: $input) {    id    actionLoading    memory    died    gameState    __typename  }}
        ''',
                                     {
                                         "input": {
                                             "type": action,
                                             "text": user_input,
                                             "id": self.adventure_id
                                         }
                                     })
        debug_print(result)

        debug_print("get story continuation")
        result = self._execute_query('''
        query ($id: String, $playPublicId: String) {
            content(id: $id, playPublicId: $playPublicId) {
                id
                actions {
                    id
                    text
                }
            }
        }
        ''',
                                     {
                                         "id": self.adventure_id
                                     })
        debug_print(result)
        actions = result['content']['actions']
        if not actions:
            return ""
        return actions[-1]['text']
Пример #3
0
def goszakup_graphql_parsing(enity_name: str, url: str, token: str, start_date: str,
                             end_date: str, limit: str, timeout: int,
                             key_column: str, retries: int, struct=None):
    """Return rows parsed from the GraphQL service of goszakup.gov.kz.

    The query is read from ``<enity_name in lower case>.gql`` located next to
    this module. It is executed page by page: after each page, the value of
    ``key_column`` in the last returned record is sent as the ``after``
    variable of the next request. Paging stops when the response has no
    records under ``enity_name``.

    This function only returns rows and does not write them to a file (a
    ``fpath`` parameter could be added later if needed). It is intended for
    small slices of data (a short date range: a day or a week, no more), not a
    full dump, because the source service may refuse large requests (e.g. with
    ``TransportProtocolError``). For that reason, no resumable retry mechanism
    (``.prs`` files) is used here.

    Args:
        enity_name: Top-level field of the response to read; also selects the query file.
        url: GraphQL endpoint URL.
        token: Value of the ``Authorization`` header (token issued by goszakup.gov.kz).
        start_date: Value of the ``from`` query variable.
        end_date: Value of the ``to`` query variable.
        limit: Page size, sent as the ``limit`` query variable.
        timeout: Seconds to sleep between pages; a falsy value disables the pause.
        key_column: Record field used as the pagination cursor.
        retries: Number of HTTP retries passed to the transport.
        struct: Forwarded to ``dict_to_csvrow`` for every record.

    Returns:
        list: Rows produced by ``dict_to_csvrow`` for all fetched records, in order.
    """
    headers = {'Authorization': token}

    query_fpath = os.path.join(os.path.dirname(os.path.realpath(__file__)),
                               f'{enity_name.lower()}.gql')
    query = gql(read_file(query_fpath))

    # Plain synchronous HTTP. An async transport would be possible, but the
    # server side restricts concurrent requests.
    # NOTE(review): verify=False disables TLS certificate verification; confirm intended.
    transport = RequestsHTTPTransport(
        url=url,
        verify=False,
        retries=retries,
        headers=headers
    )
    client = Client(transport=transport, fetch_schema_from_transport=True)

    # Filter by date range and cap the size of each page.
    params = {'from': start_date, 'to': end_date, 'limit': limit}

    start_from = None
    result = []

    while True:
        # Fresh variables per request; `params` itself is never mutated.
        variables = dict(params)
        # The first request has no cursor. Compare with None rather than truthiness:
        # a cursor of 0 or "" must still be sent, or the first page repeats forever.
        if start_from is not None:
            variables['after'] = start_from

        data = client.execute(query, variable_values=variables)

        records = data.get(enity_name)
        if not records:
            break

        # Start position for the next request (pagination).
        start_from = records[-1][key_column]

        result.extend(dict_to_csvrow(record, struct) for record in records)

        # Pause between pages to avoid being blocked by the source service.
        if timeout:
            time.sleep(timeout)

    return result
Пример #4
0
from gql import Client, gql
from gql.transport.requests import RequestsHTTPTransport

# Example: list the continents exposed by the public countries GraphQL API.
query = gql(
    """
    query getContinents {
      continents {
        code
        name
      }
    }
"""
)

transport = RequestsHTTPTransport(
    url="https://countries.trevorblades.com/",
    verify=True,
    retries=3,
)
client = Client(transport=transport, fetch_schema_from_transport=True)

result = client.execute(query)
print(result)
Пример #5
0
 def test_test(self):
     """Querying `{ test }` against the local schema returns the resolver's value."""
     local_client = Client(schema=graphql_schema)
     expected = {'test': 'test_result'}
     assert local_client.execute(gql('''{ test }''')) == expected
Пример #6
0
from gql import gql, Client

# Example: run a trivial `hello` query against a local schema object.
query = gql('''
{
  hello
}
''')
client = Client(schema=schema)

client.execute(query)
class AnnotationsRepository:
    """Read annotations from the local database and from the PAN subgraph.

    Database reads go through ``session``. Subgraph results are fetched with a
    gql client on ``PAN_SUBGRAPH``, and each returned CID is then resolved
    through The Graph's IPFS gateway.
    """

    def __init__(self, session: Session):
        self.session = session
        self.client = Client(transport=AIOHTTPTransport(url=PAN_SUBGRAPH))

    def get_subgraph_annotation(self, annotation_id):
        """Return gateway-resolved annotations for ``annotation_id`` from the subgraph.

        Returns ``[]`` when the connection is dropped (``ClientConnectionError``).
        """
        try:
            tg_resp = self.client.execute(
                ANNOTATION_FILTER_QUERY, variable_values={"id": annotation_id})
        except ClientConnectionError:
            # NOTE(review): this request goes to PAN_SUBGRAPH, not the IPFS gateway;
            # the message wording may be misleading.
            logger.warning(
                "The IPFS gateway server dropped the connection - skipping")
            return []
        return self._resolve_subgraph_response(tg_resp)

    def get_by_cid(self, annotation_id):
        """Return DB annotations whose ``subject_id`` equals ``annotation_id``, as dicts."""
        annotations = (self.session.query(Annotation).filter(
            Annotation.subject_id == annotation_id).all())
        return [a.to_dict() for a in annotations]

    @staticmethod
    def _resolve_subgraph_response(response):
        """Fetch the JSON document behind each annotation CID in a subgraph response.

        Entries whose content cannot be fetched or decoded are logged and skipped.
        A missing or null ``annotations`` field yields ``[]``.
        """
        start_time = time.perf_counter()
        output = []
        for annotation in response.get("annotations") or []:
            cid = annotation["cid"]
            try:
                annotation_data = requests.get(
                    "https://api.thegraph.com/ipfs/api/v0/cat",
                    # Let requests URL-encode the CID instead of concatenating it.
                    params={"arg": cid},
                    timeout=1.5,
                ).json()
            except (ValueError, ConnectionError, requests.exceptions.RequestException):
                # ValueError covers every JSON decode error flavour. RequestException
                # covers requests' own ConnectionError/timeouts, which the builtin
                # ConnectionError does not.
                logger.warning("Failed to decode annotation CID: %s", cid)
                continue
            output.append(annotation_data)
        logger.info("Gateway content retrieval took %s seconds",
                    time.perf_counter() - start_time)
        return output

    def list(self, filter_value, offset, limit):
        """Return a page of DB annotations, newest ``issuance_date`` first, as dicts.

        When ``filter_value`` is not None, only annotations whose
        ``original_content`` contains it (SQL ``LIKE``) are returned.
        """
        query = self.session.query(Annotation).order_by(
            desc(Annotation.issuance_date))
        if filter_value is None:
            logger.debug(
                "Fetching annotations from DB with filter=%s limit=%s offset=%s",
                filter_value, limit, offset)
        else:
            logger.debug(
                'Fetching annotations filtered by "%s" from DB with limit=%s offset=%s',
                filter_value, limit, offset)
            # NOTE(review): '%' and '_' inside filter_value act as LIKE wildcards.
            query = query.filter(
                Annotation.original_content.like(f"%{filter_value}%"))
        annotations = query.offset(offset).limit(limit).all()
        return [a.to_dict() for a in annotations]
Пример #8
0
    'indicator_id': 23,
    'query': 'Alue == "Helsinki" & Muuttuja == "Kokonaispäästöt (1000t CO2-ekv.)"',
    'lambda_over': 'Ajoneuvoluokka',
    'lambda': lambda x: x["Henkilöautot"] + x["Moottoripyörät"] + x["Pakettiautot"] + x["Kuorma-autot"] + x["Linja-autot"],
}, {
    'name': 'Joukkoliikenteen kasvihuonekaasupäästöt',
    'px_file': 'data/aluesarjat_px/ymparistotilastot/02-energia/3-energiaperainen-kuormitus/l1-liikenne-khk-paastot.px',
    'px_topic': 'Ympäristötilastot/02_Energia/3_Energiaperainen_kuormitus/L1_Liikenne_KHK_paastot.px',
    'indicator_id': 151,
    'query': 'Alue == "Helsinki" & Muuttuja == "Kokonaispäästöt (1000t CO2-ekv.)"',
    'lambda_over': 'Ajoneuvoluokka',
    # FIXME: --> Laivat otettiin pois
    'lambda': lambda x: x["Linja-autot"] + x["Lähijunat"] + x["Metrot"] + x["Raitiovaunut"],
}]

# Fetch the indicators of plan 'hnh2035' and index them by name and numeric id.
result = client.execute(GET_PLAN_INDICATORS, variable_values={'plan': 'hnh2035'})

indicators = result['planIndicators']
indicators_by_name = {indicator['name']: indicator for indicator in indicators}
indicators_by_id = {int(indicator['id']): indicator for indicator in indicators}

topic_latest_years = {}


def post_indicator_values(ind, s):
    """Post the values of ``s`` for indicator ``ind`` via ``post_values``.

    Each index label of ``s`` is turned into ``'<label>-12-31'`` (labels are
    presumably years — TODO confirm) on a copy. The caller's series is left
    untouched, so posting the same series twice no longer produces labels like
    ``'2020-12-31-12-31'``.

    Args:
        ind: Indicator record; it is pretty-printed, and ``ind['id']`` is passed to ``post_values``.
        s: Series-like object with ``.copy()`` and an ``.index`` supporting ``.map`` (e.g. pandas).
    """
    pprint(ind)
    dated = s.copy()
    dated.index = s.index.map(lambda label: str(label) + '-12-31')
    post_values(ind['id'], dated)


def update_indicator(ind):
Пример #9
0
class Api(object):
    """W&B Internal Api wrapper

    Note:
        Settings are automatically overridden by looking for
        a `wandb/settings` file in the current working directory or it's parent
        directory.  If none can be found, we look in the current users home
        directory.

    Arguments:
        default_settings(`dict`, optional): If you aren't using a settings
        file or you wish to override the section to use in the settings file
        Override the settings here.
    """

    HTTP_TIMEOUT = env.get_http_timeout(10)

    def __init__(
        self,
        default_settings=None,
        load_settings=True,
        retry_timedelta=None,
        environ=os.environ,
    ):
        """Set up settings, the GraphQL HTTP client and its retrying wrapper.

        Arguments:
            default_settings (dict, optional): Merged over the built-in defaults
                (section, git_remote, ignore_globs, base_url).
            load_settings (bool): Forwarded to ``Settings``.
            retry_timedelta (datetime.timedelta, optional): Passed to
                ``retry.Retry`` for ``self.gql``; defaults to one day
                (presumably how long failing calls keep being retried — confirm).
            environ (mapping): Environment used for env-var lookups; defaults
                to ``os.environ``.
        """
        if retry_timedelta is None:
            retry_timedelta = datetime.timedelta(days=1)
        self._environ = environ
        self.default_settings = {
            "section": "default",
            "git_remote": "origin",
            "ignore_globs": [],
            "base_url": "https://api.wandb.ai",
        }
        self.retry_timedelta = retry_timedelta
        self.default_settings.update(default_settings or {})
        self.retry_uploads = 10
        # Must exist before the client below: api_key -> api_url -> settings()
        # reads both self._settings and self.default_settings.
        self._settings = Settings(
            load_settings=load_settings,
            root_dir=self.default_settings.get("root_dir"))
        # self.git = GitRepo(remote=self.settings("git_remote"))
        # NOTE(review): git integration starts as None here; methods using
        # self.git must cope with that unless it is assigned elsewhere.
        self.git = None
        # Mutable settings set by the _file_stream_api
        self.dynamic_settings = {
            "system_sample_seconds": 2,
            "system_samples": 15,
            "heartbeat_seconds": 30,
        }
        self.client = Client(transport=RequestsHTTPTransport(
            headers={
                "User-Agent": self.user_agent,
                "X-WANDB-USERNAME": env.get_username(env=self._environ),
                "X-WANDB-USER-EMAIL": env.get_user_email(env=self._environ),
            },
            use_json=True,
            # this timeout won't apply when the DNS lookup fails. in that case, it will be 60s
            # https://bugs.python.org/issue22889
            timeout=self.HTTP_TIMEOUT,
            auth=("api", self.api_key or ""),
            url="%s/graphql" % self.settings("base_url"),
        ))
        # Retrying wrapper around self.execute; the public query entry point used
        # by the query methods (e.g. viewer, list_projects).
        self.gql = retry.Retry(
            self.execute,
            retry_timedelta=retry_timedelta,
            check_retry_fn=util.no_retry_auth,
            retryable_exceptions=(RetryError, requests.RequestException),
        )
        self._current_run_id = None
        self._file_stream_api = None

    def reauth(self):
        """Install the currently resolved API key as the transport's basic-auth pair."""
        key = self.api_key
        self.client.transport.auth = ("api", key or "")

    def relocate(self):
        """Point the transport at ``<base_url>/graphql`` using the current settings."""
        base_url = self.settings("base_url")
        self.client.transport.url = "%s/graphql" % base_url

    def execute(self, *args, **kwargs):
        """Run ``self.client.execute`` and log details when it fails with an HTTP error.

        On ``requests.exceptions.HTTPError`` the status code and response body are
        logged, and GraphQL error messages in the body are shown via
        ``display_gorilla_error_if_found``. The original exception is then
        re-raised unchanged, so the ``self.gql`` retry wrapper still sees it.
        """
        try:
            return self.client.execute(*args, **kwargs)
        except requests.exceptions.HTTPError as err:
            res = err.response
            # An HTTPError can be raised without a response attached. Guard so the
            # logging cannot replace the original error with an AttributeError.
            # (Use `is not None`: a 4xx/5xx Response is falsy.)
            if res is not None:
                logger.error("%s response executing GraphQL.", res.status_code)
                logger.error(res.text)
                self.display_gorilla_error_if_found(res)
            raise

    def display_gorilla_error_if_found(self, res):
        """Print every GraphQL error message found in the JSON body of ``res``.

        Bodies that are not JSON, are not a JSON object, or have no ``errors``
        list are ignored, and so are error entries without a message. This runs
        inside ``execute``'s error handler, so unexpected payloads must not raise
        and mask the original HTTP error.
        """
        try:
            data = res.json()
        except ValueError:
            return

        errors = data.get("errors") if isinstance(data, dict) else None
        if not isinstance(errors, list):
            return

        for err in errors:
            if not isinstance(err, dict) or not err.get("message"):
                continue
            wandb.termerror("Error while calling W&B API: {} ({})".format(
                err["message"], res))

    def disabled(self):
        """Return the ``disabled`` value of the default settings section (``False`` if unset)."""
        section = Settings.DEFAULT_SECTION
        return self._settings.get(section, "disabled", fallback=False)

    def sync_spell(self, run, env=None):
        """Register this run's URL with spell.run.

        Stores ``SPELL_RUN_URL`` in ``run.config["_wandb"]["spell_url"]``,
        persists the config, then PUTs the run URL together with
        ``WANDB_ACCESS_TOKEN`` to ``<SPELL_API_URL>/wandb_url``.

        Returns:
            The ``requests`` response of the PUT, or ``False`` when the run URL
            cannot be obtained or a ``requests`` error occurs.
        """
        try:
            environment = env or os.environ
            run.config["_wandb"]["spell_url"] = environment.get("SPELL_RUN_URL")
            run.config.persist()
            try:
                run_url = run.get_url()
            except CommError as e:
                wandb.termerror("Unable to register run with spell.run: %s" %
                                str(e))
                return False
            endpoint = environment.get("SPELL_API_URL", "https://api.spell.run") + "/wandb_url"
            payload = {
                "access_token": environment.get("WANDB_ACCESS_TOKEN"),
                "url": run_url,
            }
            return requests.put(endpoint, json=payload, timeout=2)
        except requests.RequestException:
            return False

    def save_patches(self, out_dir):
        """Save the current state of this repository to one or more patches.

        Makes one patch against HEAD and another one against the most recent
        commit that occurs in an upstream branch. This way we can be robust
        to history editing as long as the user never does "push -f" to break
        history on an upstream branch.

        Writes the first patch (only when the working tree is dirty) to
        <out_dir>/<DIFF_FNAME> and the second to
        <out_dir>/upstream_diff_<commit_id>.patch.

        Arguments:
            out_dir (str): Directory to write the patch files.

        Returns:
            False when git integration is unavailable (``self.git`` is None or not
            enabled); otherwise None. Diff failures are logged, not raised.
        """
        # __init__ sets self.git = None, so guard before touching .enabled.
        if self.git is None or not self.git.enabled:
            return False

        def write_diff(target_path, ref, cwd):
            # Stream `git diff <ref>` into target_path, including submodule
            # changes when the repository has them.
            if self.git.has_submodule_diff:
                cmd = ["git", "diff", "--submodule=diff", ref]
            else:
                cmd = ["git", "diff", ref]
            with open(target_path, "wb") as out:
                subprocess.check_call(cmd, stdout=out, cwd=cwd, timeout=5)

        try:
            root = self.git.root
            if self.git.dirty:
                # Diff against HEAD to ensure we also get changes in the index.
                write_diff(
                    os.path.join(out_dir, wandb_lib.filenames.DIFF_FNAME),
                    "HEAD", root)

            upstream_commit = self.git.get_upstream_fork_point()
            if upstream_commit and upstream_commit != self.git.repo.head.commit:
                sha = upstream_commit.hexsha
                write_diff(
                    os.path.join(out_dir, "upstream_diff_{}.patch".format(sha)),
                    sha, root)
        # TODO: A customer saw `ValueError: Reference at 'refs/remotes/origin/foo' does not exist`
        # so we now catch ValueError.  Catching this error feels too generic.
        except (
                ValueError,
                subprocess.CalledProcessError,
                subprocess.TimeoutExpired,
        ) as e:
            logger.error("Error generating diff: %s" % e)

    def set_current_run_id(self, run_id):
        """Remember ``run_id`` as the current run (readable via ``current_run_id``)."""
        setattr(self, "_current_run_id", run_id)

    @property
    def current_run_id(self):
        """The run id last stored by ``set_current_run_id`` (``None`` after construction)."""
        run_id = self._current_run_id
        return run_id

    @property
    def user_agent(self):
        """User-Agent header value identifying this client and its version."""
        version = __version__
        return "W&B Internal Client %s" % version

    @property
    def api_key(self):
        """API key from ``env.API_KEY`` in the environment, else from netrc for ``api_url``.

        Returns ``None`` when neither source provides a key.
        """
        netrc_auth = requests.utils.get_netrc_auth(self.api_url)
        netrc_key = netrc_auth[-1] if netrc_auth else None
        # Environment should take precedence over the netrc entry.
        return self._environ.get(env.API_KEY) or netrc_key

    @property
    def api_url(self):
        """The configured ``base_url`` setting."""
        return self.settings(key="base_url")

    @property
    def app_url(self):
        """Web app URL derived from ``api_url`` by ``wandb.util.app_url``."""
        api_url = self.api_url
        return wandb.util.app_url(api_url)

    def settings(self, key=None, section=None):
        """The settings overridden from the wandb/settings file.

        Built-in defaults are merged with the items of ``section`` from the
        settings file. Then ``entity``, ``project``, ``base_url`` and
        ``ignore_globs`` are each read from the default section (falling back to
        the merged value) and passed through the matching ``env.get_*`` helper
        together with ``self._environ``.

        Arguments:
            key (str, optional): If provided only this setting is returned
            section (str, optional): If provided this section of the setting file is
            used, defaults to "default"

        Returns:
            The merged dict when ``key`` is None, for example

                {
                    "entity": "models",
                    "base_url": "https://api.wandb.ai",
                    "project": None
                }

            otherwise the value stored under ``key`` (KeyError if absent).
        """
        merged = self.default_settings.copy()
        merged.update(self._settings.items(section=section))

        resolvers = (
            ("entity", env.get_entity),
            ("project", env.get_project),
            ("base_url", env.get_base_url),
            ("ignore_globs", env.get_ignore),
        )
        # Compute every override from the pre-override values, then apply them at once.
        overrides = {}
        for name, resolve in resolvers:
            file_value = self._settings.get(
                Settings.DEFAULT_SECTION,
                name,
                fallback=merged.get(name),
            )
            overrides[name] = resolve(file_value, env=self._environ)
        merged.update(overrides)

        if key is None:
            return merged
        return merged[key]

    def clear_setting(self, key, globally=False, persist=False):
        """Remove ``key`` from the default settings section (flags forwarded to ``Settings.clear``)."""
        section = Settings.DEFAULT_SECTION
        self._settings.clear(section, key, globally=globally, persist=persist)

    def set_setting(self, key, value, globally=False, persist=False):
        """Store ``key=value`` in the default settings section and propagate it.

        ``entity`` and ``project`` are also pushed into ``self._environ`` via the
        ``env`` helpers. ``base_url`` repoints the transport via ``relocate``.
        """
        section = Settings.DEFAULT_SECTION
        self._settings.set(section, key, value, globally=globally, persist=persist)

        if key in ("entity", "project"):
            setter = env.set_entity if key == "entity" else env.set_project
            setter(value, env=self._environ)
        elif key == "base_url":
            self.relocate()

    def parse_slug(self, slug, project=None, run=None):
        """Split ``"project/run"`` into a ``(project, run)`` tuple.

        Without a slash, ``project`` falls back to the configured project (raising
        ``CommError`` if none). ``run`` falls back to ``slug``, then to the
        environment's run, then to ``"latest"``.
        """
        if slug and "/" in slug:
            pieces = slug.split("/")
            return (pieces[0], pieces[1])

        resolved_project = project or self.settings().get("project")
        if resolved_project is None:
            raise CommError("No default project configured.")
        resolved_run = run or slug or env.get_run(env=self._environ)
        return (resolved_project, "latest" if resolved_run is None else resolved_run)

    @normalize_exceptions
    def viewer(self):
        """Fetch information about the authenticated viewer.

        Returns:
            dict: The ``viewer`` payload (id, entity and team names), or an
            empty dict when the response has no viewer.
        """
        query = gql("""
        query Viewer{
            viewer {
                id
                entity
                teams {
                    edges {
                        node {
                            name
                        }
                    }
                }
            }
        }
        """)
        viewer_info = self.gql(query).get("viewer")
        return viewer_info if viewer_info else {}

    @normalize_exceptions
    def list_projects(self, entity=None):
        """List (up to 10) projects belonging to an entity.

        Arguments:
            entity (str, optional): Entity to list projects for; defaults to
                the configured entity.

        Returns:
            list: Dicts with ``id``, ``name`` and ``description`` keys.
        """
        query = gql("""
        query Models($entity: String!) {
            models(first: 10, entityName: $entity) {
                edges {
                    node {
                        id
                        name
                        description
                    }
                }
            }
        }
        """)
        variables = {"entity": entity or self.settings("entity")}
        models = self.gql(query, variable_values=variables)["models"]
        return self._flatten_edges(models)

    @normalize_exceptions
    def project(self, project, entity=None):
        """Retrieve a single project.

        Arguments:
            project (str): Name of the project to fetch.
            entity (str, optional): Entity owning the project. Passed through
                as-is; ``None`` is sent to the server unchanged.

        Returns:
            dict: The ``model`` payload with ``id``, ``name``, ``repo``,
            ``dockerImage`` and ``description`` keys (or ``None`` if the
            server returns none).
        """
        query = gql("""
        query Models($entity: String, $project: String!) {
            model(name: $project, entityName: $entity) {
                id
                name
                repo
                dockerImage
                description
            }
        }
        """)
        variables = {"entity": entity, "project": project}
        response = self.gql(query, variable_values=variables)
        return response["model"]

    @normalize_exceptions
    def sweep(self, sweep, specs, project=None, entity=None):
        """Retrieve a sweep together with its runs.

        Arguments:
            sweep (str): Name of the sweep.
            specs (list): History specs forwarded to each run's
                ``sampledHistory`` field.
            project (str, optional): Defaults to the configured project.
            entity (str, optional): Defaults to the configured entity.

        Returns:
            dict: The sweep fields, with ``runs`` flattened from the relay
            edge structure into a plain list of run dicts.

        Raises:
            ValueError: if the project or the sweep does not exist.
        """
        query = gql("""
        query Models($entity: String, $project: String!, $sweep: String!, $specs: [JSONString!]!) {
            model(name: $project, entityName: $entity) {
                sweep(sweepName: $sweep) {
                    id
                    name
                    method
                    state
                    description
                    config
                    createdAt
                    heartbeatAt
                    updatedAt
                    earlyStopJobRunning
                    bestLoss
                    controller
                    scheduler
                    runs {
                        edges {
                            node {
                                name
                                state
                                config
                                exitcode
                                heartbeatAt
                                shouldStop
                                failed
                                stopped
                                running
                                summaryMetrics
                                sampledHistory(specs: $specs)
                            }
                        }
                    }
                }
            }
        }
        """)
        entity = entity or self.settings("entity")
        project = project or self.settings("project")
        variables = {
            "entity": entity,
            "project": project,
            "sweep": sweep,
            "specs": specs,
        }
        response = self.gql(query, variable_values=variables)
        model = response["model"]
        if model is None or model["sweep"] is None:
            raise ValueError("Sweep {}/{}/{} not found".format(
                entity, project, sweep))
        sweep_data = model["sweep"]
        if sweep_data:
            sweep_data["runs"] = self._flatten_edges(sweep_data["runs"])
        return sweep_data

    @normalize_exceptions
    def list_runs(self, project, entity=None):
        """List (up to 10) runs in a project.

        Arguments:
            project (str): Project to list runs from; falls back to the
                configured project when falsy.
            entity (str, optional): Defaults to the configured entity.

        Returns:
            list: Dicts with ``id``, ``name``, ``displayName`` and
            ``description`` keys.
        """
        query = gql("""
        query Buckets($model: String!, $entity: String!) {
            model(name: $model, entityName: $entity) {
                buckets(first: 10) {
                    edges {
                        node {
                            id
                            name
                            displayName
                            description
                        }
                    }
                }
            }
        }
        """)
        variables = {
            "entity": entity or self.settings("entity"),
            "model": project or self.settings("project"),
        }
        response = self.gql(query, variable_values=variables)
        return self._flatten_edges(response["model"]["buckets"])

    @normalize_exceptions
    def launch_run(self, command, project=None, entity=None, run_id=None):
        """Ask the server to launch a run in the cloud.

        When the git working tree is dirty, the output of ``git diff`` is
        sent along as the patch. When git is enabled, the working directory
        is sent relative to the repository root (prefixed with ``.``).

        Arguments:
            command (str): The command to run.
            project (str, optional): Defaults to the configured project.
            entity (str, optional): Defaults to the configured entity.
            run_id (str, optional): Defaults to ``self.current_run_id``.

        Returns:
            dict: The raw mutation response; ``launchRun`` holds
            ``podName``, ``status`` and ``runId``.
        """
        query = gql("""
        mutation launchRun(
            $entity: String
            $model: String
            $runId: String
            $image: String
            $command: String
            $patch: String
            $cwd: String
            $datasets: [String]
        ) {
            launchRun(input: {id: $runId, entityName: $entity, patch: $patch, modelName: $model,
                image: $image, command: $command, datasets: $datasets, cwd: $cwd}) {
                podName
                status
                runId
            }
        }
        """)
        run_id = run_id or self.current_run_id
        assert run_id, "run_id must be specified"

        diff_buffer = BytesIO()
        if self.git.dirty:
            self.git.repo.git.execute(["git", "diff"], output_stream=diff_buffer)
            diff_buffer.seek(0)

        if self.git.enabled:
            cwd = "." + os.getcwd().replace(self.git.repo.working_dir, "")
        else:
            cwd = "."

        variables = {
            "entity": entity or self.settings("entity"),
            "model": project or self.settings("project"),
            "command": command,
            "runId": run_id,
            "patch": diff_buffer.read().decode("utf8"),
            "cwd": cwd,
        }
        return self.gql(query, variable_values=variables)

    @normalize_exceptions
    def run_config(self, project, run=None, entity=None):
        """Get the commit, config, patch and metadata recorded for a run.

        Arguments:
            project (str): The project to download, (can include bucket)
            run (str, optional): The run to download; defaults to
                ``self.current_run_id``.
            entity (str, optional): The entity to scope this project to.

        Returns:
            tuple: ``(commit, config, patch, metadata)``. ``config`` is the
            run config decoded from JSON (``{}`` when unset). ``metadata`` is
            the decoded ``wandb-metadata.json`` file, or ``{}`` when the run
            has no such file.

        Raises:
            ValueError: if the project or the run cannot be found.
        """
        query = gql("""
        query Model($name: String!, $entity: String!, $run: String!) {
            model(name: $name, entityName: $entity) {
                bucket(name: $run) {
                    config
                    commit
                    patch
                    files(names: ["wandb-metadata.json"]) {
                        edges {
                            node {
                                url
                            }
                        }
                    }
                }
            }
        }
        """)
        run = run or self.current_run_id
        assert run, "run must be specified"
        response = self.gql(query,
                            variable_values={
                                "name": project,
                                "run": run,
                                "entity": entity
                            })
        model = response["model"]
        # A null bucket used to fall through and crash with a TypeError on the
        # subscripts below; report it the same way as a missing project.
        bucket = model["bucket"] if model is not None else None
        if bucket is None:
            raise ValueError("Run {}/{}/{} not found".format(
                entity, project, run))

        commit = bucket["commit"]
        patch = bucket["patch"]
        config = json.loads(bucket["config"] or "{}")
        file_edges = bucket["files"]["edges"]
        if file_edges:
            res = requests.get(file_edges[0]["node"]["url"])
            res.raise_for_status()
            metadata = res.json()
        else:
            metadata = {}
        return (commit, config, patch, metadata)

    @normalize_exceptions
    def run_resume_status(self, entity, project_name, name):
        """Check whether a run exists and fetch the information needed to resume it.

        When the project is found, the ``project`` setting is set to
        ``project_name``. The ``entity`` setting is set to the project's
        entity name when the server returns one.

        Arguments:
            entity (str): The entity to scope this project to (may be None).
            project_name (str): The project to look in.
            name (str): The run to look up.

        Returns:
            dict or None: The run (bucket) fields, or ``None`` when the
            project is missing or the run does not exist.
        """
        query = gql("""
        query Model($project: String!, $entity: String, $name: String!) {
            model(name: $project, entityName: $entity) {
                id
                name
                entity {
                    id
                    name
                }

                bucket(name: $name, missingOk: true) {
                    id
                    name
                    summaryMetrics
                    displayName
                    logLineCount
                    historyLineCount
                    eventsLineCount
                    historyTail
                    eventsTail
                    config
                }
            }
        }
        """)

        response = self.gql(
            query,
            variable_values={
                "entity": entity,
                "project": project_name,
                "name": name,
            },
        )

        if "model" not in response or "bucket" not in (response["model"]
                                                       or {}):
            return None

        project = response["model"]
        self.set_setting("project", project_name)
        # GraphQL always includes the `entity` key, so test its value: a null
        # entity previously crashed on the ["name"] subscript.
        project_entity = project.get("entity")
        if project_entity:
            self.set_setting("entity", project_entity["name"])

        return project["bucket"]

    @normalize_exceptions
    def check_stop_requested(self, project_name, entity_name, run_id):
        """Return whether the server has marked a run as stopped.

        Arguments:
            project_name (str): Project containing the run.
            entity_name (str): Entity owning the project.
            run_id (str): Run to check; falls back to
                ``self.current_run_id`` when falsy.

        Returns:
            The run's ``stopped`` flag, or ``False`` when the project or run
            is not found.
        """
        query = gql("""
        query Model($projectName: String, $entityName: String, $runId: String!) {
            project(name:$projectName, entityName:$entityName) {
                run(name:$runId) {
                    stopped
                }
            }
        }
        """)
        run_id = run_id or self.current_run_id
        assert run_id, "run_id must be specified"
        variables = {
            "projectName": project_name,
            "entityName": entity_name,
            "runId": run_id,
        }
        response = self.gql(query, variable_values=variables)

        project_info = response.get("project", None) or {}
        run_info = project_info.get("run", None)
        return run_info["stopped"] if run_info else False

    def format_project(self, project):
        """Normalize a project name into a lowercase, dash-separated slug.

        Runs of non-word characters become a single ``-``. Leading and
        trailing dashes and underscores are then trimmed.
        """
        slug = re.sub(r"\W+", "-", project.lower())
        return slug.strip("-_")

    @normalize_exceptions
    def upsert_project(self, project, id=None, description=None, entity=None):
        """Create or update a project.

        The name is normalized with ``format_project``, and the current git
        remote URL is sent as the repo.

        Arguments:
            project (str): The project to create.
            id (str, optional): Id of an existing project to update.
            description (str, optional): A description of this project.
            entity (str, optional): Defaults to the configured entity.

        Returns:
            dict: The upserted project's ``name`` and ``description``.
        """
        mutation = gql("""
        mutation UpsertModel($name: String!, $id: String, $entity: String!, $description: String, $repo: String)  {
            upsertModel(input: { id: $id, name: $name, entityName: $entity, description: $description, repo: $repo }) {
                model {
                    name
                    description
                }
            }
        }
        """)
        variables = {
            "name": self.format_project(project),
            "entity": entity or self.settings("entity"),
            "description": description,
            "repo": self.git.remote_url,
            "id": id,
        }
        result = self.gql(mutation, variable_values=variables)
        return result["upsertModel"]["model"]

    @normalize_exceptions
    def pop_from_run_queue(self, entity=None, project=None):
        """Pop the next item from a project's run queue.

        Arguments:
            entity (str): Entity owning the queue (sent as-is).
            project (str): Project owning the queue (sent as-is).

        Returns:
            dict: The ``popFromRunQueue`` payload with ``runQueueItemId`` and
            ``runSpec``.
        """
        mutation = gql("""
        mutation popFromRunQueue($entity: String!, $project: String!)  {
            popFromRunQueue(input: { entityName: $entity, projectName: $project }) {
                runQueueItemId
                runSpec
            }
        }
        """)
        variables = dict(entity=entity, project=project)
        result = self.gql(mutation, variable_values=variables)
        return result["popFromRunQueue"]

    @normalize_exceptions
    def upsert_run(
        self,
        id=None,
        name=None,
        project=None,
        host=None,
        group=None,
        tags=None,
        config=None,
        description=None,
        entity=None,
        state=None,
        display_name=None,
        notes=None,
        repo=None,
        job_type=None,
        program_path=None,
        commit=None,
        sweep_name=None,
        summary_metrics=None,
        num_retries=None,
    ):
        """Create or update a run (bucket).

        Once the server responds, the returned project name and entity name
        (when present) are stored as the ``project`` and ``entity`` settings.

        Arguments:
            id (str, optional): The existing run to update.
            name (str, optional): The name of the run to create.
            project (str, optional): The name of the project.
            host (str, optional): Host name; dropped when the ``anonymous``
                setting is ``"true"``.
            group (str, optional): Name of the group this run is a part of.
            tags (list, optional): Tags for the run.
            config (dict, optional): The latest config params; JSON-encoded
                before sending.
            description (str, optional): Description; blank or
                whitespace-only values are sent as ``None``.
            entity (str, optional): Defaults to the configured entity.
            state (str, optional): State of the program.
            display_name (str, optional): Human readable run name.
            notes (str, optional): Free-form notes.
            repo (str, optional): Url of the program's repository.
            job_type (str, optional): Type of job, e.g 'train'.
            program_path (str, optional): Path to the program.
            commit (str, optional): The Git SHA to associate the run with.
            sweep_name (str, optional): Sweep this run belongs to.
            summary_metrics (str, optional): The JSON summary metrics.
            num_retries (int, optional): Forwarded to ``self.gql`` only when
                given.

        Returns:
            dict: The upserted bucket, including its nested ``project``.
        """
        mutation = gql("""
        mutation UpsertBucket(
            $id: String, $name: String,
            $project: String,
            $entity: String!,
            $groupName: String,
            $description: String,
            $displayName: String,
            $notes: String,
            $commit: String,
            $config: JSONString,
            $host: String,
            $debug: Boolean,
            $program: String,
            $repo: String,
            $jobType: String,
            $state: String,
            $sweep: String,
            $tags: [String!],
            $summaryMetrics: JSONString,
        ) {
            upsertBucket(input: {
                id: $id,
                name: $name,
                groupName: $groupName,
                modelName: $project,
                entityName: $entity,
                description: $description,
                displayName: $displayName,
                notes: $notes,
                config: $config,
                commit: $commit,
                host: $host,
                debug: $debug,
                jobProgram: $program,
                jobRepo: $repo,
                jobType: $jobType,
                state: $state,
                sweep: $sweep,
                tags: $tags,
                summaryMetrics: $summaryMetrics,
            }) {
                bucket {
                    id
                    name
                    displayName
                    description
                    config
                    project {
                        id
                        name
                        entity {
                            id
                            name
                        }
                    }
                }
            }
        }
        """)
        encoded_config = None if config is None else json.dumps(config)
        has_description = bool(description) and not description.isspace()
        clean_description = description if has_description else None
        gql_kwargs = {} if num_retries is None else {"num_retries": num_retries}

        variable_values = {
            "id": id,
            "entity": entity or self.settings("entity"),
            "name": name,
            "project": project,
            "groupName": group,
            "tags": tags,
            "description": clean_description,
            "config": encoded_config,
            "commit": commit,
            "displayName": display_name,
            "notes": notes,
            "host":
            None if self.settings().get("anonymous") == "true" else host,
            "debug": env.is_debug(env=self._environ),
            "repo": repo,
            "program": program_path,
            "jobType": job_type,
            "state": state,
            "sweep": sweep_name,
            "summaryMetrics": summary_metrics,
        }

        response = self.gql(mutation,
                            variable_values=variable_values,
                            **gql_kwargs)

        bucket = response["upsertBucket"]["bucket"]
        project_info = bucket.get("project")
        if project_info:
            self.set_setting("project", project_info["name"])
            entity_info = project_info.get("entity")
            if entity_info:
                self.set_setting("entity", entity_info["name"])
        return bucket

    @normalize_exceptions
    def upload_urls(self,
                    project,
                    files,
                    run=None,
                    entity=None,
                    description=None):
        """Generate temporary resumable upload urls.

        Arguments:
            project (str): The project to upload to.
            files (list or dict): The filenames to upload (for a dict, its
                keys are used).
            run (str, optional): The run to upload to; defaults to
                ``self.current_run_id``.
            entity (str, optional): Defaults to the configured entity.
            description (str, optional): Forwarded as the bucket ``desc``.

        Returns:
            (bucket_id, upload_headers, file_info)
            bucket_id: id of the bucket we upload to.
            upload_headers: the ``uploadHeaders`` value returned by the server.
            file_info: A dict of filenames to file info (``name``, upload
                ``url``, ``updatedAt``), e.g.
                {
                    'weights.h5': { "url": "https://weights.url" },
                    'model.json': { "url": "https://model.json", "updatedAt": '2013-04-26T22:22:23.832Z' },
                }

        Raises:
            CommError: if the project or run does not exist.
        """
        query = gql("""
        query Model($name: String!, $files: [String]!, $entity: String!, $run: String!, $description: String) {
            model(name: $name, entityName: $entity) {
                bucket(name: $run, desc: $description) {
                    id
                    files(names: $files) {
                        uploadHeaders
                        edges {
                            node {
                                name
                                url(upload: true)
                                updatedAt
                            }
                        }
                    }
                }
            }
        }
        """)
        run_id = run or self.current_run_id
        # Validate the resolved id: `run` itself is None whenever the caller
        # relies on the current-run default, so asserting on it was wrong.
        assert run_id, "run must be specified"
        entity = entity or self.settings("entity")
        query_result = self.gql(
            query,
            variable_values={
                "name": project,
                "run": run_id,
                "entity": entity,
                "description": description,
                "files": list(files),
            },
        )

        model = query_result["model"]
        bucket = model["bucket"] if model else None
        if not bucket:
            raise CommError("Run does not exist {}/{}/{}.".format(
                entity, project, run_id))

        file_info = {
            file_node["name"]: file_node
            for file_node in self._flatten_edges(bucket["files"])
        }
        return bucket["id"], bucket["files"]["uploadHeaders"], file_info

    @normalize_exceptions
    def download_urls(self, project, run=None, entity=None):
        """Fetch download info for every file of a run.

        Arguments:
            project (str): The project to download from.
            run (str, optional): The run to download; defaults to
                ``self.current_run_id``.
            entity (str, optional): Defaults to the configured entity.

        Returns:
            A dict mapping file names to file info (empty nodes are skipped):

                {
                    'weights.h5': { "url": "https://weights.url", "updatedAt": '2013-04-26T22:22:23.832Z', 'md5': 'mZFLkyvTelC5g8XnyQrpOw==' },
                    'model.json': { "url": "https://model.url", "updatedAt": '2013-04-26T22:22:23.832Z', 'md5': 'mZFLkyvTelC5g8XnyQrpOw==' }
                }
        """
        query = gql("""
        query Model($name: String!, $entity: String!, $run: String!)  {
            model(name: $name, entityName: $entity) {
                bucket(name: $run) {
                    files {
                        edges {
                            node {
                                name
                                url
                                md5
                                updatedAt
                            }
                        }
                    }
                }
            }
        }
        """)
        run = run or self.current_run_id
        assert run, "run must be specified"
        variables = {
            "name": project,
            "run": run,
            "entity": entity or self.settings("entity"),
        }
        query_result = self.gql(query, variable_values=variables)
        bucket_files = query_result["model"]["bucket"]["files"]

        urls = {}
        for node in self._flatten_edges(bucket_files):
            if node:
                urls[node["name"]] = node
        return urls

    @normalize_exceptions
    def download_url(self, project, file_name, run=None, entity=None):
        """Fetch download info for a single file of a run.

        Arguments:
            project (str): The project to download from.
            file_name (str): The name of the file to download.
            run (str, optional): The run to download from; defaults to
                ``self.current_run_id``.
            entity (str, optional): Defaults to the configured entity.

        Returns:
            The file info dict, or ``None`` when the project, run or file is
            missing, or the file has no ``updatedAt`` value:

                { "url": "https://weights.url", "updatedAt": '2013-04-26T22:22:23.832Z', 'md5': 'mZFLkyvTelC5g8XnyQrpOw==' }
        """
        query = gql("""
        query Model($name: String!, $fileName: String!, $entity: String!, $run: String!)  {
            model(name: $name, entityName: $entity) {
                bucket(name: $run) {
                    files(names: [$fileName]) {
                        edges {
                            node {
                                name
                                url
                                md5
                                updatedAt
                            }
                        }
                    }
                }
            }
        }
        """)
        run = run or self.current_run_id
        assert run, "run must be specified"
        query_result = self.gql(
            query,
            variable_values={
                "name": project,
                "run": run,
                "fileName": file_name,
                "entity": entity or self.settings("entity"),
            },
        )
        model = query_result["model"]
        # A missing run (null bucket) is treated like a missing project
        # instead of crashing on the ["files"] subscript.
        bucket = model["bucket"] if model else None
        if not bucket:
            return None
        files = self._flatten_edges(bucket["files"])
        if files and files[0].get("updatedAt"):
            return files[0]
        return None

    @normalize_exceptions
    def download_file(self, url):
        """Start a streaming download of ``url``.

        Arguments:
            url (str): The url to download.

        Returns:
            tuple: ``(content_length, response)``. ``content_length`` is taken
            from the ``content-length`` header (0 when absent), and
            ``response`` is the still-streaming ``requests`` response.
        """
        response = requests.get(url, stream=True)
        response.raise_for_status()
        content_length = int(response.headers.get("content-length", 0))
        return content_length, response

    @normalize_exceptions
    def download_write_file(self, metadata, out_dir=None):
        """Download a run file into ``out_dir`` (default: the wandb dir).

        Arguments:
            metadata (dict): File info as returned by ``download_urls``; the
                ``name``, ``md5`` and ``url`` keys are used.
            out_dir (str, optional): Destination directory; defaults to the
                ``wandb_dir`` setting.

        Returns:
            tuple: ``(path, response)``. ``path`` is the local destination
            path. ``response`` is the streaming response, or ``None`` if the
            file already existed at ``path`` with a matching md5.
        """
        file_name = metadata["name"]
        path = os.path.join(out_dir or self.settings("wandb_dir"), file_name)
        # Compare against the destination path. Checking the bare file name
        # looked in the current working directory instead of out_dir, so
        # up-to-date files were re-downloaded (or stale ones kept).
        if self.file_current(path, metadata["md5"]):
            return path, None

        _, response = self.download_file(metadata["url"])

        with open(path, "wb") as out_file:
            for chunk in response.iter_content(chunk_size=1024):
                out_file.write(chunk)

        return path, response

    @normalize_exceptions
    def register_agent(self,
                       host,
                       sweep_id=None,
                       project_name=None,
                       entity=None):
        """Register a new sweep agent with the server.

        Arguments:
            host (str): Hostname of the machine running the agent.
            sweep_id (str, optional): Sweep the agent works on.
            project_name (str, optional): Project containing the sweep;
                defaults to the configured project.
            entity (str, optional): Defaults to the configured entity.

        Returns:
            dict: The created agent (its ``id``).

        Raises:
            UsageError: when the server rejects the request with a 4xx.
        """
        mutation = gql("""
        mutation CreateAgent(
            $host: String!
            $projectName: String!,
            $entityName: String!,
            $sweep: String!
        ) {
            createAgent(input: {
                host: $host,
                projectName: $projectName,
                entityName: $entityName,
                sweep: $sweep,
            }) {
                agent {
                    id
                }
            }
        }
        """)
        entity = self.settings("entity") if entity is None else entity
        project_name = (self.settings("project")
                        if project_name is None else project_name)

        def _fail_fast_on_4xx(err):
            """Retry predicate: retry anything except an HTTP 4xx.

            A 4xx (validation / not found) is turned into a ``UsageError``
            carrying the server's first error message.
            """
            if not isinstance(err, requests.HTTPError):
                return True
            status = err.response.status_code
            if not 400 <= status < 500:
                return True
            payload = json.loads(err.response.content)
            raise UsageError(payload["errors"][0]["message"])

        variables = {
            "host": host,
            "entityName": entity,
            "projectName": project_name,
            "sweep": sweep_id,
        }
        response = self.gql(mutation,
                            variable_values=variables,
                            check_retry_fn=_fail_fast_on_4xx)
        return response["createAgent"]["agent"]

    def agent_heartbeat(self, agent_id, metrics, run_states):
        """Notify server about agent state, receive commands.

        Any failure talking to the server is logged and an empty command list
        is returned, so the caller's heartbeat loop keeps running.

        Arguments:
            agent_id (str): agent_id
            metrics (dict): system metrics (JSON-encoded before sending)
            run_states (dict): run_id: state mapping (JSON-encoded)
        Returns:
            List of commands to execute.
        """
        mutation = gql("""
        mutation Heartbeat(
            $id: ID!,
            $metrics: JSONString,
            $runState: JSONString
        ) {
            agentHeartbeat(input: {
                id: $id,
                metrics: $metrics,
                runState: $runState
            }) {
                agent {
                    id
                }
                commands
            }
        }
        """)
        try:
            response = self.gql(
                mutation,
                variable_values={
                    "id": agent_id,
                    "metrics": json.dumps(metrics),
                    "runState": json.dumps(run_states),
                },
            )
        except Exception as e:
            # GQL raises exceptions whose first arg is a stringified python
            # dict. Other errors (connection failures, plain messages, empty
            # args) must not make this handler itself raise, so fall back to
            # str(e) when the dict form cannot be parsed.
            message = str(e)
            if e.args and isinstance(e.args[0], str):
                try:
                    parsed = ast.literal_eval(e.args[0])
                except (ValueError, SyntaxError, TypeError, MemoryError,
                        RecursionError):
                    parsed = None
                if isinstance(parsed, dict) and "message" in parsed:
                    message = parsed["message"]
            logger.error("Error communicating with W&B: %s", message)
            return []
        else:
            return json.loads(response["agentHeartbeat"]["commands"])

    @normalize_exceptions
    def upsert_sweep(
        self,
        config,
        controller=None,
        scheduler=None,
        obj_id=None,
        project=None,
        entity=None,
    ):
        """Create or update a sweep and return its name.

        After a successful call, the returned project name and entity name
        (when present) are stored as the ``project`` and ``entity`` settings.

        Arguments:
            config (dict): Sweep configuration. It is serialized with
                ``yaml.dump``, and its ``description`` key (if any) is sent
                separately.
            controller: Forwarded as the ``controller`` variable.
            scheduler: Forwarded as the ``scheduler`` variable.
            obj_id (str, optional): Id of an existing sweep to update.
            project (str, optional): Defaults to the configured project.
            entity (str, optional): Defaults to the configured entity.

        Returns:
            str: The sweep name.

        Raises:
            UsageError: when the server rejects the request with a 4xx.
        """
        project_query = """
                    project {
                        id
                        name
                        entity {
                            id
                            name
                        }
                    }
        """
        mutation_str = """
        mutation UpsertSweep(
            $id: ID,
            $config: String,
            $description: String,
            $entityName: String!,
            $projectName: String!,
            $controller: JSONString,
            $scheduler: JSONString
        ) {
            upsertSweep(input: {
                id: $id,
                config: $config,
                description: $description,
                entityName: $entityName,
                projectName: $projectName,
                controller: $controller,
                scheduler: $scheduler
            }) {
                sweep {
                    name
                    _PROJECT_QUERY_
                }
            }
        }
        """
        # TODO(jhr): we need protocol versioning to know schema is not supported.
        # Until then, try the mutation that also selects `project`, and fall
        # back to the bare one for servers whose schema lacks that field.
        candidates = (
            gql(mutation_str.replace("_PROJECT_QUERY_", project_query)),
            gql(mutation_str.replace("_PROJECT_QUERY_", "")),
        )

        # TODO(jhr): generalize error handling routines
        def _fail_fast_on_4xx(err):
            """Retry predicate: retry anything except an HTTP 4xx.

            A 4xx (validation error) is turned into a ``UsageError``
            carrying the server's first error message.
            """
            if not isinstance(err, requests.HTTPError):
                return True
            status = err.response.status_code
            if not 400 <= status < 500:
                return True
            payload = json.loads(err.response.content)
            raise UsageError(payload["errors"][0]["message"])

        last_error = None
        for mutation in candidates:
            try:
                response = self.gql(
                    mutation,
                    variable_values={
                        "id": obj_id,
                        "config": yaml.dump(config),
                        "description": config.get("description"),
                        "entityName": entity or self.settings("entity"),
                        "projectName": project or self.settings("project"),
                        "controller": controller,
                        "scheduler": scheduler,
                    },
                    check_retry_fn=_fail_fast_on_4xx,
                )
            except UsageError:
                raise
            except Exception as exc:
                # The schema error is generic, so any other failure just
                # moves on to the next candidate mutation.
                last_error = exc
            else:
                last_error = None
                break
        if last_error:
            raise last_error

        sweep = response["upsertSweep"]["sweep"]
        project_info = sweep.get("project")
        if project_info:
            self.set_setting("project", project_info["name"])
            entity_info = project_info.get("entity")
            if entity_info:
                self.set_setting("entity", entity_info["name"])

        return sweep["name"]

    @normalize_exceptions
    def create_anonymous_api_key(self):
        """Create a new anonymous user and return the name of its API key."""
        mutation = gql("""
        mutation CreateAnonymousApiKey {
            createAnonymousEntity(input: {}) {
                apiKey {
                    name
                }
            }
        }
        """)

        response = self.gql(mutation, variable_values={})
        api_key = response["createAnonymousEntity"]["apiKey"]
        return api_key["name"]

    def file_current(self, fname, md5):
        """Return True if ``fname`` is an existing file whose md5 equals ``md5``.

        The checksum is only computed when the file exists.
        """
        if not os.path.isfile(fname):
            return False
        return util.md5_file(fname) == md5

    @normalize_exceptions
    def pull(self, project, run=None, entity=None):
        """Download all files of a run into the wandb directory.

        Arguments:
            project (str): The project to download, or a ``project/run`` slug.
            run (str, optional): The run to download.
            entity (str, optional): Defaults to the configured entity.

        Returns:
            list: The streaming responses for files that were actually
            downloaded; files already up to date are skipped.
        """
        project, run = self.parse_slug(project, run=run)
        assert run, "run must be specified"
        urls = self.download_urls(project, run, entity)

        downloaded = []
        for metadata in urls.values():
            _, response = self.download_write_file(metadata)
            if response:
                downloaded.append(response)
        return downloaded

    def get_project(self):
        """Return the configured default project name."""
        project_name = self.settings("project")
        return project_name

    def _status_request(self, url, length):
        """Ask the upload server how much of a resumable upload it has received.

        Sends an empty PUT with ``Content-Range: bytes */<length>`` and
        returns the raw ``requests`` response.
        """
        status_headers = {
            "Content-Length": "0",
            "Content-Range": "bytes */%i" % length,
        }
        return requests.put(url=url, headers=status_headers)

    def _flatten_edges(self, response):
        """Return an array from the nested graphql relay structure"""
        return [node["node"] for node in response["edges"]]
# fmt: on
# Fetch enrichment data (company info and detected tools) for a domain and
# print it. A leftover `ipdb.set_trace()` breakpoint used to halt the script
# here before the query ran.
# NOTE(review): `client` is assumed to be a gql Client created earlier in the
# script; it is not defined in this section.
query = gql("""
{
    enrichment(domain: "airbnb.com"){
        domain
        companyId
        companyName
        companyTools {
            count
            pageInfo {
                hasNextPage
                endCursor
            }
            edges {
                node {
                    tool{
                    id
                    name
                    }
                    sourcesSummary
                    sources
                }
            }
        }
    }
}
""")

print(client.execute(query))
Пример #11
0
class SpeckleClient:
    """GraphQL client for a Speckle server.

    Holds a synchronous HTTP gql client, an optional websocket gql client
    (created by ``authenticate``), and one resource wrapper per API area.
    """

    DEFAULT_HOST = "speckle.xyz"
    USE_SSL = True

    def __init__(self,
                 host: str = DEFAULT_HOST,
                 use_ssl: bool = USE_SSL) -> None:
        """Build the server URLs and an unauthenticated HTTP client.

        Arguments:
            host {str} -- server host; a leading protocol ("x://" or "//")
                and a trailing slash are stripped
            use_ssl {bool} -- use https/wss instead of http/ws
        """
        ws_protocol = "ws"
        http_protocol = "http"

        if use_ssl:
            ws_protocol = "wss"
            http_protocol = "https"

        # sanitise host input by removing protocol and trailing slash
        host = re.sub(r"((^\w+:|^)\/\/)|(\/$)", "", host)

        self.url = f"{http_protocol}://{host}"
        self.graphql = self.url + "/graphql"
        self.ws_url = f"{ws_protocol}://{host}/graphql"
        self.me = None

        self.httpclient = Client(transport=RequestsHTTPTransport(
            url=self.graphql, verify=True, retries=3))
        # No websocket client until authenticate() is called.
        self.wsclient = None

        self._init_resources()

    def authenticate(self, token: str) -> None:
        """Authenticate the client using a personal access token
        The token is saved in the client object and a synchronous GraphQL entrypoint is created

        Both the HTTP and websocket clients are rebuilt with the token, and
        all resource wrappers are re-created to use them.

        Arguments:
            token {str} -- an api token
        """
        self.me = {"token": token}
        headers = {
            "Authorization": f"Bearer {self.me['token']}",
            "Content-Type": "application/json",
        }
        httptransport = RequestsHTTPTransport(url=self.graphql,
                                              headers=headers,
                                              verify=True,
                                              retries=3)
        wstransport = WebsocketsTransport(
            url=self.ws_url,
            init_payload={"Authorization": f"Bearer {self.me['token']}"},
        )
        self.httpclient = Client(transport=httptransport)
        self.wsclient = Client(transport=wstransport)

        self._init_resources()

    def execute_query(self, query: str) -> Dict:
        """Execute ``query`` with the HTTP client and return its result."""
        return self.httpclient.execute(query)

    def _init_resources(self) -> None:
        """(Re)create the resource wrappers bound to the current clients."""
        self.stream = stream.Resource(me=self.me,
                                      basepath=self.url,
                                      client=self.httpclient)
        self.commit = commit.Resource(me=self.me,
                                      basepath=self.url,
                                      client=self.httpclient)
        self.branch = branch.Resource(me=self.me,
                                      basepath=self.url,
                                      client=self.httpclient)
        self.object = object.Resource(me=self.me,
                                      basepath=self.url,
                                      client=self.httpclient)
        self.server = server.Resource(me=self.me,
                                      basepath=self.url,
                                      client=self.httpclient)
        self.user = user.Resource(me=self.me,
                                  basepath=self.url,
                                  client=self.httpclient)
        # Subscriptions go over the websocket URL/client (None until authenticated).
        self.subscribe = subscriptions.Resource(
            me=self.me,
            basepath=self.ws_url,
            client=self.wsclient,
        )

    def __getattr__(self, name):
        """Build a resource wrapper for ``name`` from the ``resources`` package.

        Only called for attributes not found through normal lookup.

        Raises:
            SpeckleException: if the resource cannot be found or constructed.
        """
        try:
            attr = getattr(resources, name)
            return attr.Resource(me=self.me,
                                 basepath=self.url,
                                 client=self.httpclient)
        except Exception as ex:
            # Narrowed from a bare ``except:`` so KeyboardInterrupt/SystemExit
            # are not swallowed; the underlying cause is kept via chaining.
            raise SpeckleException(
                f"Method {name} is not supported by the SpeckleClient class"
            ) from ex
Пример #12
0
class Api(object):
    """W&B Internal Api wrapper

    Note:
        Settings are automatically overridden by looking for
        a `wandb/settings` file in the current working directory or its parent
        directory.  If none can be found, we look in the current users home
        directory.

    Args:
        default_settings(:obj:`dict`, optional): If you aren't using a settings
        file or you wish to override the section to use in the settings file
        Override the settings here.
        load_settings (bool, optional): Passed through to ``Settings``.
        retry_timedelta (datetime.timedelta, optional): How long GraphQL calls
        are retried; defaults to one day.
        environ (mapping, optional): Environment to read/write configuration
        from; defaults to ``os.environ``.
    """

    HTTP_TIMEOUT = env.get_http_timeout(10)

    def __init__(
        self,
        default_settings=None,
        load_settings=True,
        retry_timedelta=None,
        environ=os.environ,
    ):
        if retry_timedelta is None:
            retry_timedelta = datetime.timedelta(days=1)
        self._environ = environ
        self.default_settings = {
            "section": "default",
            "git_remote": "origin",
            "ignore_globs": [],
            "base_url": "https://api.wandb.ai",
        }
        self.retry_timedelta = retry_timedelta
        self.default_settings.update(default_settings or {})
        self.retry_uploads = 10
        self._settings = Settings(
            load_settings=load_settings,
            root_dir=self.default_settings.get("root_dir"))
        # Git integration is disabled here, so ``self.git`` stays None and every
        # method that touches it must handle that case.
        # self.git = GitRepo(remote=self.settings("git_remote"))
        self.git = None
        # Mutable settings set by the _file_stream_api
        self.dynamic_settings = {
            "system_sample_seconds": 2,
            "system_samples": 15,
            "heartbeat_seconds": 30,
        }
        self.client = Client(transport=RequestsHTTPTransport(
            headers={
                "User-Agent": self.user_agent,
                "X-WANDB-USERNAME": env.get_username(env=self._environ),
                "X-WANDB-USER-EMAIL": env.get_user_email(env=self._environ),
            },
            use_json=True,
            # this timeout won't apply when the DNS lookup fails. in that case, it will be 60s
            # https://bugs.python.org/issue22889
            timeout=self.HTTP_TIMEOUT,
            auth=("api", self.api_key or ""),
            url="%s/graphql" % self.settings("base_url"),
        ))
        # All GraphQL calls go through this retrying wrapper around execute().
        self.gql = retry.Retry(
            self.execute,
            retry_timedelta=retry_timedelta,
            check_retry_fn=util.no_retry_auth,
            retryable_exceptions=(RetryError, requests.RequestException),
        )
        self._current_run_id = None
        self._file_stream_api = None

    def reauth(self):
        """Ensures the current api key is set in the transport"""
        self.client.transport.auth = ("api", self.api_key or "")

    def relocate(self):
        """Ensures the current api points to the right server"""
        self.client.transport.url = "%s/graphql" % self.settings("base_url")

    def execute(self, *args, **kwargs):
        """Run a GraphQL request, logging server details on HTTP failures.

        On ``requests.exceptions.HTTPError`` the status code and body are
        logged, any GraphQL error messages are shown to the user, and the
        original exception is re-raised with its traceback.
        """
        try:
            return self.client.execute(*args, **kwargs)
        except requests.exceptions.HTTPError as err:
            res = err.response
            logger.error("%s response executing GraphQL.", res.status_code)
            logger.error(res.text)
            self.display_gorilla_error_if_found(res)
            raise

    def display_gorilla_error_if_found(self, res):
        """Print each GraphQL ``errors[].message`` found in a JSON response body."""
        try:
            data = res.json()
        except ValueError:
            return

        # A JSON body is not guaranteed to be an object.
        if not isinstance(data, dict):
            return
        errors = data.get("errors")
        if isinstance(errors, list):
            for err in errors:
                if not isinstance(err, dict) or not err.get("message"):
                    continue
                wandb.termerror("Error while calling W&B API: {} ({})".format(
                    err["message"], res))

    def disabled(self):
        """Return the ``disabled`` setting of the default section (False if unset)."""
        return self._settings.get(Settings.DEFAULT_SECTION,
                                  "disabled",
                                  fallback=False)

    def sync_spell(self, run, env=None):
        """Syncs this run with spell

        Stores ``SPELL_RUN_URL`` in the run config and PUTs the run URL to the
        Spell API. Returns the ``requests`` response, or False on failure.
        """
        try:
            env = env or os.environ
            run.config["_wandb"]["spell_url"] = env.get("SPELL_RUN_URL")
            run.config.persist()
            try:
                url = run.get_url()
            except CommError as e:
                wandb.termerror("Unable to register run with spell.run: %s" %
                                str(e))
                return False
            return requests.put(
                env.get("SPELL_API_URL", "https://api.spell.run") +
                "/wandb_url",
                json={
                    "access_token": env.get("WANDB_ACCESS_TOKEN"),
                    "url": url
                },
                timeout=2,
            )
        except requests.RequestException:
            return False

    def save_patches(self, out_dir):
        """Save the current state of this repository to one or more patches.

        Makes one patch against HEAD and another one against the most recent
        commit that occurs in an upstream branch. This way we can be robust
        to history editing as long as the user never does "push -f" to break
        history on an upstream branch.

        Writes the first patch to <out_dir>/<DIFF_FNAME> and the second to
        <out_dir>/upstream_diff_<commit_id>.patch.

        Args:
            out_dir (str): Directory to write the patch files.

        Returns:
            False if git is unavailable or disabled, otherwise None. Diff
            failures are logged rather than raised.
        """
        # ``self.git`` is None when git integration is not configured.
        if self.git is None or not self.git.enabled:
            return False

        try:
            root = self.git.root
            if self.git.dirty:
                patch_path = os.path.join(out_dir,
                                          wandb_lib.filenames.DIFF_FNAME)
                if self.git.has_submodule_diff:
                    with open(patch_path, "wb") as patch:
                        # we diff against HEAD to ensure we get changes in the index
                        subprocess.check_call(
                            ["git", "diff", "--submodule=diff", "HEAD"],
                            stdout=patch,
                            cwd=root,
                            timeout=5,
                        )
                else:
                    with open(patch_path, "wb") as patch:
                        subprocess.check_call(["git", "diff", "HEAD"],
                                              stdout=patch,
                                              cwd=root,
                                              timeout=5)

            upstream_commit = self.git.get_upstream_fork_point()
            if upstream_commit and upstream_commit != self.git.repo.head.commit:
                sha = upstream_commit.hexsha
                upstream_patch_path = os.path.join(
                    out_dir, "upstream_diff_{}.patch".format(sha))
                if self.git.has_submodule_diff:
                    with open(upstream_patch_path, "wb") as upstream_patch:
                        subprocess.check_call(
                            ["git", "diff", "--submodule=diff", sha],
                            stdout=upstream_patch,
                            cwd=root,
                            timeout=5,
                        )
                else:
                    with open(upstream_patch_path, "wb") as upstream_patch:
                        subprocess.check_call(
                            ["git", "diff", sha],
                            stdout=upstream_patch,
                            cwd=root,
                            timeout=5,
                        )
        # TODO: A customer saw `ValueError: Reference at 'refs/remotes/origin/foo' does not exist`
        # so we now catch ValueError.  Catching this error feels too generic.
        except (
                ValueError,
                subprocess.CalledProcessError,
                subprocess.TimeoutExpired,
        ) as e:
            logger.error("Error generating diff: %s", e)

    def set_current_run_id(self, run_id):
        """Set the run id used as a default by run-scoped calls."""
        self._current_run_id = run_id

    @property
    def current_run_id(self):
        """The run id set by ``set_current_run_id`` (None by default)."""
        return self._current_run_id

    @property
    def user_agent(self):
        """User-Agent header value sent with every GraphQL request."""
        return "W&B Internal Client %s" % __version__

    @property
    def api_key(self):
        """API key from netrc for ``api_url``; the env.API_KEY variable wins if set."""
        auth = requests.utils.get_netrc_auth(self.api_url)
        key = None
        if auth:
            key = auth[-1]
        # Environment should take precedence
        if self._environ.get(env.API_KEY):
            key = self._environ.get(env.API_KEY)
        return key

    @property
    def api_url(self):
        """The configured ``base_url`` setting."""
        return self.settings("base_url")

    @property
    def app_url(self):
        """Web app URL derived from ``api_url``."""
        return wandb.util.app_url(self.api_url)

    def settings(self, key=None, section=None):
        """The settings overridden from the wandb/settings file.

        Args:
            key (str, optional): If provided only this setting is returned
            section (str, optional): If provided this section of the setting file is
            used, defaults to "default"

        Returns:
            A dict with the current settings

                {
                    "entity": "models",
                    "base_url": "https://api.wandb.ai",
                    "project": None
                }
        """
        result = self.default_settings.copy()
        result.update(self._settings.items(section=section))
        result.update({
            "entity":
            env.get_entity(
                self._settings.get(
                    Settings.DEFAULT_SECTION,
                    "entity",
                    fallback=result.get("entity"),
                ),
                env=self._environ,
            ),
            "project":
            env.get_project(
                self._settings.get(
                    Settings.DEFAULT_SECTION,
                    "project",
                    fallback=result.get("project"),
                ),
                env=self._environ,
            ),
            "base_url":
            env.get_base_url(
                self._settings.get(
                    Settings.DEFAULT_SECTION,
                    "base_url",
                    fallback=result.get("base_url"),
                ),
                env=self._environ,
            ),
            "ignore_globs":
            env.get_ignore(
                self._settings.get(
                    Settings.DEFAULT_SECTION,
                    "ignore_globs",
                    fallback=result.get("ignore_globs"),
                ),
                env=self._environ,
            ),
        })

        return result if key is None else result[key]

    def clear_setting(self, key, globally=False, persist=False):
        """Clear ``key`` from the default settings section."""
        self._settings.clear(Settings.DEFAULT_SECTION,
                             key,
                             globally=globally,
                             persist=persist)

    def set_setting(self, key, value, globally=False, persist=False):
        """Set ``key`` in the default section and propagate it.

        ``entity``/``project`` are also written to the environment mapping;
        changing ``base_url`` repoints the transport via ``relocate``.
        """
        self._settings.set(Settings.DEFAULT_SECTION,
                           key,
                           value,
                           globally=globally,
                           persist=persist)
        if key == "entity":
            env.set_entity(value, env=self._environ)
        elif key == "project":
            env.set_project(value, env=self._environ)
        elif key == "base_url":
            self.relocate()

    def parse_slug(self, slug, project=None, run=None):
        """Split a ``project/run`` slug into ``(project, run)``.

        If ``slug`` contains "/", its first two components are the project and
        run (further components are ignored). Otherwise the project falls back
        to ``project`` then the configured project, and the run to ``run``,
        then ``slug``, then the environment's run, then ``"latest"``.

        Raises:
            CommError: if no project is given and none is configured.
        """
        if slug and "/" in slug:
            parts = slug.split("/")
            project = parts[0]
            run = parts[1]
        else:
            project = project or self.settings().get("project")
            if project is None:
                raise CommError("No default project configured.")
            run = run or slug or env.get_run(env=self._environ)
            if run is None:
                run = "latest"
        return (project, run)

    def _flatten_edges(self, response):
        """Return a list of nodes from a nested GraphQL relay ``edges`` structure.

        Used by ``list_projects``, ``sweep`` and ``list_runs``.
        """
        return [edge["node"] for edge in response["edges"]]

    @normalize_exceptions
    def viewer(self):
        """Return the authenticated viewer (id, entity, teams), or {} if absent."""
        query = gql("""
        query Viewer{
            viewer {
                id
                entity
                teams {
                    edges {
                        node {
                            name
                        }
                    }
                }
            }
        }
        """)
        res = self.gql(query)
        return res.get("viewer") or {}

    @normalize_exceptions
    def list_projects(self, entity=None):
        """Lists projects in W&B scoped by entity.

        Args:
            entity (str, optional): The entity to scope this project to.

        Returns:
                [{"id","name","description"}]
        """
        query = gql("""
        query Models($entity: String!) {
            models(first: 10, entityName: $entity) {
                edges {
                    node {
                        id
                        name
                        description
                    }
                }
            }
        }
        """)
        return self._flatten_edges(
            self.gql(
                query,
                variable_values={"entity": entity
                                 or self.settings("entity")})["models"])

    @normalize_exceptions
    def project(self, project, entity=None):
        """Retrieve project

        Args:
            project (str): The project to get details for
            entity (str, optional): The entity to scope this project to.

        Returns:
                {"id","name","repo","dockerImage","description"}
        """
        query = gql("""
        query Models($entity: String, $project: String!) {
            model(name: $project, entityName: $entity) {
                id
                name
                repo
                dockerImage
                description
            }
        }
        """)
        return self.gql(query,
                        variable_values={
                            "entity": entity,
                            "project": project
                        })["model"]

    @normalize_exceptions
    def sweep(self, sweep, specs, project=None, entity=None):
        """Retrieve sweep.

        Args:
            sweep (str): The sweep to get details for
            specs (str): history specs
            project (str, optional): The project to scope this sweep to.
            entity (str, optional): The entity to scope this sweep to.

        Returns:
            dict: The sweep's fields, with ``runs`` flattened into a list of
            run dicts.

        Raises:
            ValueError: if the project or sweep is not found.
        """
        query = gql("""
        query Models($entity: String, $project: String!, $sweep: String!, $specs: [JSONString!]!) {
            model(name: $project, entityName: $entity) {
                sweep(sweepName: $sweep) {
                    id
                    name
                    method
                    state
                    description
                    config
                    createdAt
                    heartbeatAt
                    updatedAt
                    earlyStopJobRunning
                    bestLoss
                    controller
                    scheduler
                    runs {
                        edges {
                            node {
                                name
                                state
                                config
                                exitcode
                                heartbeatAt
                                shouldStop
                                failed
                                stopped
                                running
                                summaryMetrics
                                sampledHistory(specs: $specs)
                            }
                        }
                    }
                }
            }
        }
        """)
        entity = entity or self.settings("entity")
        project = project or self.settings("project")
        response = self.gql(
            query,
            variable_values={
                "entity": entity,
                "project": project,
                "sweep": sweep,
                "specs": specs,
            },
        )
        if response["model"] is None or response["model"]["sweep"] is None:
            raise ValueError("Sweep {}/{}/{} not found".format(
                entity, project, sweep))
        data = response["model"]["sweep"]
        if data:
            data["runs"] = self._flatten_edges(data["runs"])
        return data

    @normalize_exceptions
    def list_runs(self, project, entity=None):
        """Lists runs in W&B scoped by project.

        Args:
            project (str): The project to scope the runs to
            entity (str, optional): The entity to scope this project to.  Defaults to public models

        Returns:
                [{"id","name","displayName","description"}]
        """
        query = gql("""
        query Buckets($model: String!, $entity: String!) {
            model(name: $model, entityName: $entity) {
                buckets(first: 10) {
                    edges {
                        node {
                            id
                            name
                            displayName
                            description
                        }
                    }
                }
            }
        }
        """)
        return self._flatten_edges(
            self.gql(
                query,
                variable_values={
                    "entity": entity or self.settings("entity"),
                    "model": project or self.settings("project"),
                },
            )["model"]["buckets"])

    @normalize_exceptions
    def launch_run(self, command, project=None, entity=None, run_id=None):
        """Launch a run in the cloud.

        Args:
            command (str): The command to run
            project (str): The project to scope the runs to
            entity (str, optional): The entity to scope this project to.  Defaults to public models
            run_id (str, optional): The run_id to scope to

        Returns:
                {"launchRun": {"podName","status","runId"}}
        """
        query = gql("""
        mutation launchRun(
            $entity: String
            $model: String
            $runId: String
            $image: String
            $command: String
            $patch: String
            $cwd: String
            $datasets: [String]
        ) {
            launchRun(input: {id: $runId, entityName: $entity, patch: $patch, modelName: $model,
                image: $image, command: $command, datasets: $datasets, cwd: $cwd}) {
                podName
                status
                runId
            }
        }
        """)
        run_id = run_id or self.current_run_id
        assert run_id, "run_id must be specified"
        patch = BytesIO()
        # Without git integration there is no patch and cwd stays ".".
        if self.git is not None and self.git.dirty:
            self.git.repo.git.execute(["git", "diff"], output_stream=patch)
            patch.seek(0)
        cwd = "."
        if self.git is not None and self.git.enabled:
            cwd += os.getcwd().replace(self.git.repo.working_dir, "")
        return self.gql(
            query,
            variable_values={
                "entity": entity or self.settings("entity"),
                "model": project or self.settings("project"),
                "command": command,
                "runId": run_id,
                "patch": patch.read().decode("utf8"),
                "cwd": cwd,
            },
        )

    @normalize_exceptions
    def run_config(self, project, run=None, entity=None):
        """Get the relevant configs for a run

        Args:
            project (str): The project to download, (can include bucket)
            run (str): The run to download
            entity (str, optional): The entity to scope this project to.
                Defaults to the configured entity.

        Returns:
            tuple: ``(commit, config, patch, metadata)`` where ``config`` is
            the decoded run config and ``metadata`` the downloaded
            ``wandb-metadata.json`` (or {} if the run has none).

        Raises:
            ValueError: if the project or run is not found.
        """
        query = gql("""
        query Model($name: String!, $entity: String!, $run: String!) {
            model(name: $name, entityName: $entity) {
                bucket(name: $run) {
                    config
                    commit
                    patch
                    files(names: ["wandb-metadata.json"]) {
                        edges {
                            node {
                                url
                            }
                        }
                    }
                }
            }
        }
        """)
        # ``$entity`` is a required (String!) variable, so never send None
        # when a configured default exists.
        entity = entity or self.settings("entity")
        run = run or self.current_run_id
        assert run, "run must be specified"
        response = self.gql(query,
                            variable_values={
                                "name": project,
                                "run": run,
                                "entity": entity
                            })
        if response["model"] is None or response["model"]["bucket"] is None:
            raise ValueError("Run {}/{}/{} not found".format(
                entity, project, run))
        run = response["model"]["bucket"]
        commit = run["commit"]
        patch = run["patch"]
        config = json.loads(run["config"] or "{}")
        if len(run["files"]["edges"]) > 0:
            url = run["files"]["edges"][0]["node"]["url"]
            res = requests.get(url)
            res.raise_for_status()
            metadata = res.json()
        else:
            metadata = {}
        return (commit, config, patch, metadata)

    @normalize_exceptions
    def run_resume_status(self, entity, project_name, name):
        """Check if a run exists and get resume information.

        Args:
            entity (str, optional): The entity to scope this project to.
            project_name (str): The project to download, (can include bucket)
            name (str): The run to download

        Returns:
            The run ("bucket") dict, or None if the response has no project or
            no ``bucket`` key. As a side effect the project (and entity, when
            returned) settings are updated.
        """
        query = gql("""
        query Model($project: String!, $entity: String, $name: String!) {
            model(name: $project, entityName: $entity) {
                id
                name
                entity {
                    id
                    name
                }

                bucket(name: $name, missingOk: true) {
                    id
                    name
                    summaryMetrics
                    displayName
                    logLineCount
                    historyLineCount
                    eventsLineCount
                    historyTail
                    eventsTail
                    config
                }
            }
        }
        """)

        response = self.gql(
            query,
            variable_values={
                "entity": entity,
                "project": project_name,
                "name": name,
            },
        )

        if "model" not in response or "bucket" not in (response["model"]
                                                       or {}):
            return None

        project = response["model"]
        self.set_setting("project", project_name)
        # ``entity`` may be present but null; only record a real entity.
        project_entity = project.get("entity")
        if project_entity:
            self.set_setting("entity", project_entity["name"])

        return project["bucket"]

    @normalize_exceptions
    def check_stop_requested(self, project_name, entity_name, run_id):
        """Return the run's ``stopped`` flag, or False if project/run is missing."""
        query = gql("""
        query Model($projectName: String, $entityName: String, $runId: String!) {
            project(name:$projectName, entityName:$entityName) {
                run(name:$runId) {
                    stopped
                }
            }
        }
        """)
        run_id = run_id or self.current_run_id
        assert run_id, "run_id must be specified"
        response = self.gql(
            query,
            variable_values={
                "projectName": project_name,
                "entityName": entity_name,
                "runId": run_id,
            },
        )

        project = response.get("project", None)
        if not project:
            return False
        run = project.get("run", None)
        if not run:
            return False

        return run["stopped"]

    def format_project(self, project):
        """Lower-case ``project``, collapse non-word runs to "-", trim "-"/"_" ends."""
        return re.sub(r"\W+", "-", project.lower()).strip("-_")

    @normalize_exceptions
    def upsert_project(self, project, id=None, description=None, entity=None):
        """Create a new project

        Args:
            project (str): The project to create
            id (str, optional): The id of an existing project to update
            description (str, optional): A description of this project
            entity (str, optional): The entity to scope this project to.

        Returns:
            dict: The upserted model's ``name`` and ``description``.
        """
        mutation = gql("""
        mutation UpsertModel($name: String!, $id: String, $entity: String!, $description: String, $repo: String)  {
            upsertModel(input: { id: $id, name: $name, entityName: $entity, description: $description, repo: $repo }) {
                model {
                    name
                    description
                }
            }
        }
        """)
        response = self.gql(
            mutation,
            variable_values={
                "name": self.format_project(project),
                "entity": entity or self.settings("entity"),
                "description": description,
                # No repo URL is sent when git integration is unavailable.
                "repo": self.git.remote_url if self.git is not None else None,
                "id": id,
            },
        )
        return response["upsertModel"]["model"]

    @normalize_exceptions
    def pop_from_run_queue(self, entity=None, project=None):
        """Pop the next item from the run queue.

        Returns:
            dict: ``{"runQueueItemId", "runSpec"}`` from the server.
        """
        mutation = gql("""
        mutation popFromRunQueue($entity: String!, $project: String!)  {
            popFromRunQueue(input: { entityName: $entity, projectName: $project }) {
                runQueueItemId
                runSpec
            }
        }
        """)
        response = self.gql(mutation,
                            variable_values={
                                "entity": entity,
                                "project": project
                            })
        return response["popFromRunQueue"]

    @normalize_exceptions
    def upsert_run(
        self,
        id=None,
        name=None,
        project=None,
        host=None,
        group=None,
        tags=None,
        config=None,
        description=None,
        entity=None,
        state=None,
        display_name=None,
        notes=None,
        repo=None,
        job_type=None,
        program_path=None,
        commit=None,
        sweep_name=None,
        summary_metrics=None,
        num_retries=None,
    ):
        """Update a run

        Args:
            id (str, optional): The existing run to update
            name (str, optional): The name of the run to create
            group (str, optional): Name of the group this run is a part of
            project (str, optional): The name of the project
            config (dict, optional): The latest config params
            description (str, optional): A description of this project
            entity (str, optional): The entity to scope this project to.
            repo (str, optional): Url of the program's repository.
            state (str, optional): State of the program.
            job_type (str, optional): Type of job, e.g 'train'.
            program_path (str, optional): Path to the program.
            commit (str, optional): The Git SHA to associate the run with
            summary_metrics (str, optional): The JSON summary metrics
            num_retries (int, optional): Forwarded to the retrying GraphQL call.

        Returns:
            dict: The upserted run ("bucket"), including its project/entity.
        """
        mutation = gql("""
        mutation UpsertBucket(
            $id: String, $name: String,
            $project: String,
            $entity: String!,
            $groupName: String,
            $description: String,
            $displayName: String,
            $notes: String,
            $commit: String,
            $config: JSONString,
            $host: String,
            $debug: Boolean,
            $program: String,
            $repo: String,
            $jobType: String,
            $state: String,
            $sweep: String,
            $tags: [String!],
            $summaryMetrics: JSONString,
        ) {
            upsertBucket(input: {
                id: $id,
                name: $name,
                groupName: $groupName,
                modelName: $project,
                entityName: $entity,
                description: $description,
                displayName: $displayName,
                notes: $notes,
                config: $config,
                commit: $commit,
                host: $host,
                debug: $debug,
                jobProgram: $program,
                jobRepo: $repo,
                jobType: $jobType,
                state: $state,
                sweep: $sweep,
                tags: $tags,
                summaryMetrics: $summaryMetrics,
            }) {
                bucket {
                    id
                    name
                    displayName
                    description
                    config
                    project {
                        id
                        name
                        entity {
                            id
                            name
                        }
                    }
                }
            }
        }
        """)
        if config is not None:
            config = json.dumps(config)
        if not description or description.isspace():
            description = None

        kwargs = {}
        if num_retries is not None:
            kwargs["num_retries"] = num_retries

        variable_values = {
            "id": id,
            "entity": entity or self.settings("entity"),
            "name": name,
            "project": project,
            "groupName": group,
            "tags": tags,
            "description": description,
            "config": config,
            "commit": commit,
            "displayName": display_name,
            "notes": notes,
            # Hostname is withheld for anonymous users.
            "host":
            None if self.settings().get("anonymous") == "true" else host,
            "debug": env.is_debug(env=self._environ),
            "repo": repo,
            "program": program_path,
            "jobType": job_type,
            "state": state,
            "sweep": sweep_name,
            "summaryMetrics": summary_metrics,
        }

        response = self.gql(mutation,
                            variable_values=variable_values,
                            **kwargs)

        run = response["upsertBucket"]["bucket"]
        # Remember the project/entity the server resolved for this run.
        if run_project := run.get("project"):
            self.set_setting("project", run_project["name"])
            if run_entity := run_project.get("entity"):
                self.set_setting("entity", run_entity["name"])
        return run
Пример #13
0
class Graph:
    """Retrying wrapper around a gql ``Client`` pointed at a single subgraph URL."""

    def __init__(self, url: str) -> None:
        """Create the gql client for the subgraph served at ``url``.

        May raise:
        - RemoteError: if building the client raises a
          ``requests.RequestException`` (the original error is chained).
        """
        transport = RequestsHTTPTransport(url=url)
        try:
            self.client = Client(transport=transport,
                                 fetch_schema_from_transport=False)
        except requests.exceptions.RequestException as e:
            raise RemoteError(
                f'Failed to connect to the graph at {url} due to {str(e)}'
            ) from e

    def query(
        self,
        querystr: str,
        param_types: Optional[Dict[str, Any]] = None,
        param_values: Optional[Dict[str, Any]] = None,
    ) -> Dict[str, Any]:
        """Queries The Graph for a particular query

        Args:
            querystr: Body of the GraphQL query. When ``param_types`` is given,
                an operation header ending in ``{`` is prepended, so
                ``querystr`` must supply the matching closing ``}``.
            param_types: Variable declarations rendered into the operation
                header, e.g. ``{'$id': 'ID!'}`` becomes ``query ($id: ID!){``.
                The rendering is textual (quotes stripped, braces swapped for
                parentheses), so names/types must not contain ``"``, ``{``
                or ``}``.
            param_values: Values for the declared variables, forwarded as
                gql ``variable_values``.

        Returns:
            The result returned by ``Client.execute``.

        May raise:
        - RemoteError: If there is a problem querying the subgraph and there
        are no retries left, or if no attempts are configured at all.
        """
        prefix = ''
        if param_types is not None:
            prefix = 'query '
            prefix += json.dumps(param_types).replace('"', '').replace(
                '{', '(').replace('}', ')')
            prefix += '{'

        querystr = prefix + querystr
        log.debug(f'Querying The Graph for {querystr}')

        retries_left = QUERY_RETRY_TIMES
        while retries_left > 0:
            try:
                result = self.client.execute(gql(querystr),
                                             variable_values=param_values)
            # gql v2 raises bare Exception for query/server failures, so nothing
            # narrower can be caught; requests' RequestException is a subclass.
            # NB: the lack of good API error handling by The Graph means every
            # failure triggers the retry logic.
            # TODO: upgrade to gql v3 and amend this code on any improvement
            # The Graph does on its API error handling.
            except Exception as e:  # pylint: disable=broad-except
                retries_left -= 1
                base_msg = f'The Graph query to {querystr} failed due to {str(e)}'
                if not retries_left:
                    raise RemoteError(f'{base_msg}. No retries left.') from e
                # Exponential backoff: factor * 2**(number of failures so far).
                sleep_seconds = RETRY_BACKOFF_FACTOR * pow(
                    2, QUERY_RETRY_TIMES - retries_left)
                retry_msg = (
                    f'Retrying query after {sleep_seconds} seconds. '
                    f'Retries left: {retries_left}.')
                log.error(f'{base_msg}. {retry_msg}')
                gevent.sleep(sleep_seconds)
            else:
                log.debug('Got result from The Graph query')
                return result

        # Only reachable when QUERY_RETRY_TIMES <= 0; previously this fell
        # through to `return result` and raised UnboundLocalError.
        raise RemoteError(
            f'The Graph query to {querystr} was not attempted: no retries configured.'
        )
Пример #14
0
class DagsterGraphQLClient:
    """Official Dagster Python Client for GraphQL

    Utilizes the gql library to dispatch queries over HTTP to a remote Dagster GraphQL Server

    As of now, all operations on this client are synchronous.

    Intended usage:

    .. code-block:: python

        client = DagsterGraphQLClient("localhost", port_number=3000)
        status = client.get_run_status(**SOME_RUN_ID**)

    Args:
        hostname (str): Hostname for the Dagster GraphQL API, like `localhost` or
            `dagit.dagster.YOUR_ORG_HERE`.
        port_number (Optional[int], optional): Optional port number to connect to on the host.
            Defaults to None.
        transport (Optional[Transport], optional): A custom transport to use to connect to the
            GraphQL API with (e.g. for custom auth). Defaults to None, in which case a
            ``RequestsHTTPTransport`` for the computed URL is used.
        use_https (bool, optional): Whether to use https in the URL connection string for the
            GraphQL API. Defaults to False.

    Raises:
        DagsterGraphQLClientError: if constructing the underlying gql client raises a
            :py:class:`~requests.exceptions.ConnectionError` (chained as the cause).
    """
    def __init__(
        self,
        hostname: str,
        port_number: Optional[int] = None,
        transport: Optional[Transport] = None,
        use_https: bool = False,
    ):
        experimental_class_warning(self.__class__.__name__)

        self._hostname = check.str_param(hostname, "hostname")
        self._port_number = check.opt_int_param(port_number, "port_number")
        self._use_https = check.bool_param(use_https, "use_https")

        self._url = (("https://" if self._use_https else "http://") +
                     (f"{self._hostname}:{self._port_number}"
                      if self._port_number else self._hostname) + "/graphql")

        self._transport = check.opt_inst_param(transport, "transport", Transport)
        if self._transport is None:
            # Built only when needed, so a caller-supplied transport does not
            # also cause a throwaway default transport to be constructed.
            self._transport = RequestsHTTPTransport(url=self._url, use_json=True)
        try:
            # With fetch_schema_from_transport=True, client construction may
            # talk to the server, hence the connection-error handling here.
            self._client = Client(transport=self._transport,
                                  fetch_schema_from_transport=True)
        except requests.exceptions.ConnectionError as exc:
            raise DagsterGraphQLClientError(
                f"Error when connecting to url {self._url}. " +
                f"Did you specify hostname: {self._hostname} " +
                (f"and port_number: {self._port_number} " if self.
                 _port_number else "") + "correctly?") from exc

    def _execute(self, query: str, variables: Optional[Dict[str, Any]] = None):
        """Parse and execute ``query``; wrap any gql failure in DagsterGraphQLClientError."""
        try:
            return self._client.execute(gql(query), variable_values=variables)
        except Exception as exc:  # catch generic Exception from the gql client
            raise DagsterGraphQLClientError(
                f"Exception occured during execution of query \n{query}\n with variables \n{variables}\n"
            ) from exc

    def _get_repo_locations_and_names_with_pipeline(
            self, pipeline_name: str) -> List[PipelineInfo]:
        """Return every PipelineInfo on the server whose pipeline_name matches.

        Raises:
            DagsterGraphQLClientError(typename, message): if the repositories query
                does not return a RepositoryConnection.
        """
        res_data = self._execute(
            CLIENT_GET_REPO_LOCATIONS_NAMES_AND_PIPELINES_QUERY)
        query_res = res_data["repositoriesOrError"]
        repo_connection_status = query_res["__typename"]
        if repo_connection_status == "RepositoryConnection":
            # Each repository node yields an iterable of PipelineInfo; flatten them.
            valid_nodes: Iterable[PipelineInfo] = chain.from_iterable(
                map(PipelineInfo.from_node, query_res["nodes"]))
            return [
                info for info in valid_nodes
                if info.pipeline_name == pipeline_name
            ]
        else:
            raise DagsterGraphQLClientError(repo_connection_status,
                                            query_res["message"])

    def _core_submit_execution(
        self,
        pipeline_name: str,
        repository_location_name: Optional[str] = None,
        repository_name: Optional[str] = None,
        run_config: Optional[Any] = None,
        mode: Optional[str] = None,
        preset: Optional[str] = None,
        tags: Optional[Dict[str, Any]] = None,
        solid_selection: Optional[List[str]] = None,
        is_using_job_op_graph_apis: Optional[bool] = False,
    ):
        """Shared implementation of submit_pipeline_execution / submit_job_execution.

        Resolves the repository location/name when not given, launches the run via
        CLIENT_SUBMIT_PIPELINE_RUN_MUTATION and returns the new run id, translating
        GraphQL error results into DagsterGraphQLClientError.
        """
        check.opt_str_param(repository_location_name,
                            "repository_location_name")
        check.opt_str_param(repository_name, "repository_name")
        check.str_param(pipeline_name, "pipeline_name")
        check.opt_str_param(mode, "mode")
        check.opt_str_param(preset, "preset")
        # NOTE(review): opt_dict_param presumably normalizes None to {}, in which
        # case the `run_config is not None` tests below reduce to checking `mode`.
        run_config = check.opt_dict_param(run_config, "run_config")

        # The following invariant will never fail when a job is executed
        check.invariant(
            (mode is not None and run_config is not None)
            or preset is not None,
            "Either a mode and run_config or a preset must be specified in order to "
            f"submit the pipeline {pipeline_name} for execution",
        )
        tags = validate_tags(tags)

        pipeline_or_job = "Job" if is_using_job_op_graph_apis else "Pipeline"

        if not repository_location_name or not repository_name:
            pipeline_info_lst = self._get_repo_locations_and_names_with_pipeline(
                pipeline_name)
            if len(pipeline_info_lst) == 0:
                raise DagsterGraphQLClientError(
                    f"{pipeline_or_job}NotFoundError",
                    f"No {'jobs' if is_using_job_op_graph_apis else 'pipelines'} with the name `{pipeline_name}` exist",
                )
            elif len(pipeline_info_lst) == 1:
                pipeline_info = pipeline_info_lst[0]
                repository_location_name = pipeline_info.repository_location_name
                repository_name = pipeline_info.repository_name
            else:
                raise DagsterGraphQLClientError(
                    "Must specify repository_location_name and repository_name"
                    f" since there are multiple {'jobs' if is_using_job_op_graph_apis else 'pipelines'} with the name {pipeline_name}."
                    f"\n\tchoose one of: {pipeline_info_lst}")

        execution_params: Dict[str, Any] = {
            "selector": {
                "repositoryLocationName": repository_location_name,
                "repositoryName": repository_name,
                "pipelineName": pipeline_name,
                "solidSelection": solid_selection,
            }
        }
        if preset is not None:
            execution_params["preset"] = preset
        has_mode_and_config = mode is not None and run_config is not None
        if has_mode_and_config:
            execution_params["runConfigData"] = run_config
            execution_params["mode"] = mode
        # Tags were previously attached only together with mode/run_config, so a
        # preset-only submission silently dropped them. Attach them whenever
        # present, and keep sending empty metadata in the mode/run_config case.
        if tags or has_mode_and_config:
            execution_params["executionMetadata"] = {
                "tags": [{
                    "key": k,
                    "value": v
                } for k, v in tags.items()]
            } if tags else {}
        variables: Dict[str, Any] = {"executionParams": execution_params}

        res_data: Dict[str, Any] = self._execute(
            CLIENT_SUBMIT_PIPELINE_RUN_MUTATION, variables)
        query_result = res_data["launchPipelineExecution"]
        query_result_type = query_result["__typename"]
        if (query_result_type == "LaunchRunSuccess"
                or query_result_type == "LaunchPipelineRunSuccess"):
            return query_result["run"]["runId"]
        elif query_result_type == "InvalidStepError":
            raise DagsterGraphQLClientError(query_result_type,
                                            query_result["invalidStepKey"])
        elif query_result_type == "InvalidOutputError":
            error_info = InvalidOutputErrorInfo(
                step_key=query_result["stepKey"],
                invalid_output_name=query_result["invalidOutputName"],
            )
            raise DagsterGraphQLClientError(query_result_type, body=error_info)
        elif (query_result_type == "RunConfigValidationInvalid"
              or query_result_type == "PipelineConfigValidationInvalid"):
            raise DagsterGraphQLClientError(query_result_type,
                                            query_result["errors"])
        else:
            # query_result_type is a ConflictingExecutionParamsError, a PresetNotFoundError
            # a PipelineNotFoundError, a RunConflict, or a PythonError
            raise DagsterGraphQLClientError(query_result_type,
                                            query_result["message"])

    def submit_pipeline_execution(
        self,
        pipeline_name: str,
        repository_location_name: Optional[str] = None,
        repository_name: Optional[str] = None,
        run_config: Optional[Any] = None,
        mode: Optional[str] = None,
        preset: Optional[str] = None,
        tags: Optional[Dict[str, Any]] = None,
        solid_selection: Optional[List[str]] = None,
    ) -> str:
        """Submits a Pipeline with attached configuration for execution.

        Args:
            pipeline_name (str): The pipeline's name
            repository_location_name (Optional[str], optional): The name of the repository location where
                the pipeline is located. If omitted, the client will try to infer the repository location
                from the available options on the Dagster deployment. Defaults to None.
            repository_name (Optional[str], optional): The name of the repository where the pipeline is located.
                If omitted, the client will try to infer the repository from the available options
                on the Dagster deployment. Defaults to None.
            run_config (Optional[Any], optional): This is the run config to execute the pipeline with.
                Note that runConfigData is any-typed in the GraphQL type system. This type is used when passing in
                an arbitrary object for run config. However, it must conform to the constraints of the config
                schema for this pipeline. If it does not, the client will throw a DagsterGraphQLClientError with a message of
                RunConfigValidationInvalid. Defaults to None.
            mode (Optional[str], optional): The mode to run the pipeline with. If you have not
                defined any custom modes for your pipeline, the default mode is "default". Defaults to None.
            preset (Optional[str], optional): The name of a pre-defined preset to use instead of a
                run config. Defaults to None.
            tags (Optional[Dict[str, Any]], optional): A set of tags to add to the pipeline execution.
                Sent as execution metadata whether a preset or a mode/run_config is used. Defaults to None.
            solid_selection (Optional[List[str]], optional): Passed through as the selector's
                ``solidSelection``. Defaults to None.

        Raises:
            DagsterGraphQLClientError("InvalidStepError", invalid_step_key): the pipeline has an invalid step
            DagsterGraphQLClientError("InvalidOutputError", body=error_object): some solid has an invalid output within the pipeline.
                The error_object is of type dagster_graphql.InvalidOutputErrorInfo.
            DagsterGraphQLClientError("ConflictingExecutionParamsError", message): a preset and a run_config & mode are present
                that conflict with one another
            DagsterGraphQLClientError("PresetNotFoundError", message): if the provided preset name is not found
            DagsterGraphQLClientError("RunConflict", message): a `DagsterRunConflict` occurred during execution.
                This indicates that a conflicting pipeline run already exists in run storage.
            DagsterGraphQLClientError("RunConfigValidationInvalid" or "PipelineConfigValidationInvalid", errors):
                the run_config is not in the expected format for the pipeline
            DagsterGraphQLClientError("PipelineNotFoundError", message): the requested pipeline does not exist
            DagsterGraphQLClientError(message): the repository was not given and several pipelines share the name
            DagsterGraphQLClientError("PythonError", message): an internal framework error occurred

        Returns:
            str: run id of the submitted pipeline run
        """
        return self._core_submit_execution(
            pipeline_name,
            repository_location_name,
            repository_name,
            run_config,
            mode,
            preset,
            tags,
            solid_selection,
            is_using_job_op_graph_apis=False,
        )

    def submit_job_execution(
        self,
        job_name: str,
        repository_location_name: Optional[str] = None,
        repository_name: Optional[str] = None,
        run_config: Optional[Dict[str, Any]] = None,
        tags: Optional[Dict[str, Any]] = None,
        op_selection: Optional[List[str]] = None,
    ) -> str:
        """Submits a job with attached configuration for execution.

        Jobs are always submitted with mode "default" and no preset.

        Args:
            job_name (str): The job's name
            repository_location_name (Optional[str]): The name of the repository location where
                the job is located. If omitted, the client will try to infer the repository location
                from the available options on the Dagster deployment. Defaults to None.
            repository_name (Optional[str]): The name of the repository where the job is located.
                If omitted, the client will try to infer the repository from the available options
                on the Dagster deployment. Defaults to None.
            run_config (Optional[Dict[str, Any]]): This is the run config to execute the job with.
                Note that runConfigData is any-typed in the GraphQL type system. This type is used when passing in
                an arbitrary object for run config. However, it must conform to the constraints of the config
                schema for this job. If it does not, the client will throw a DagsterGraphQLClientError with a message of
                RunConfigValidationInvalid. Defaults to None.
            tags (Optional[Dict[str, Any]]): A set of tags to add to the job execution.
            op_selection (Optional[List[str]]): Passed through as the selector's ``solidSelection``.
                Defaults to None.

        Raises:
            DagsterGraphQLClientError("InvalidStepError", invalid_step_key): the job has an invalid step
            DagsterGraphQLClientError("InvalidOutputError", body=error_object): some solid has an invalid output within the job.
                The error_object is of type dagster_graphql.InvalidOutputErrorInfo.
            DagsterGraphQLClientError("RunConflict", message): a `DagsterRunConflict` occurred during execution.
                This indicates that a conflicting job run already exists in run storage.
            DagsterGraphQLClientError("RunConfigValidationInvalid" or "PipelineConfigValidationInvalid", errors):
                the run_config is not in the expected format for the job
            DagsterGraphQLClientError("JobNotFoundError", message): the requested job does not exist locally
                (a not-found error type reported by the server is passed through as-is)
            DagsterGraphQLClientError("PythonError", message): an internal framework error occurred

        Returns:
            str: run id of the submitted job run
        """
        return self._core_submit_execution(
            pipeline_name=job_name,
            repository_location_name=repository_location_name,
            repository_name=repository_name,
            run_config=run_config,
            mode="default",
            preset=None,
            tags=tags,
            solid_selection=op_selection,
            is_using_job_op_graph_apis=True,
        )

    def get_run_status(self, run_id: str) -> PipelineRunStatus:
        """Get the status of a given Pipeline Run

        Args:
            run_id (str): run id of the requested pipeline run.

        Raises:
            DagsterGraphQLClientError(error_type, message): if the server reports an error for this run id
                (e.g. the run is not found, or a "PythonError" on internal framework errors)

        Returns:
            PipelineRunStatus: returns a status Enum describing the state of the requested pipeline run
        """
        check.str_param(run_id, "run_id")

        res_data: Dict[str, Dict[str, Any]] = self._execute(
            GET_PIPELINE_RUN_STATUS_QUERY, {"runId": run_id})
        query_result: Dict[str, Any] = res_data["pipelineRunOrError"]
        query_result_type: str = query_result["__typename"]
        if query_result_type == "PipelineRun" or query_result_type == "Run":
            return PipelineRunStatus(query_result["status"])
        else:
            raise DagsterGraphQLClientError(query_result_type,
                                            query_result["message"])

    def reload_repository_location(
            self,
            repository_location_name: str) -> ReloadRepositoryLocationInfo:
        """Reloads a Dagster Repository Location, which reloads all repositories in that repository location.

        This is useful in a variety of contexts, including refreshing Dagit without restarting
        the server.

        Args:
            repository_location_name (str): The name of the repository location

        Returns:
            ReloadRepositoryLocationInfo: SUCCESS if the location reloaded; otherwise FAILURE with
                failure_type "PythonError" (load error) or the server's error type name
                (e.g. ReloadNotSupported, RepositoryLocationNotFound) and its message.
        """
        check.str_param(repository_location_name, "repository_location_name")

        res_data: Dict[str, Dict[str, Any]] = self._execute(
            RELOAD_REPOSITORY_LOCATION_MUTATION,
            {"repositoryLocationName": repository_location_name},
        )

        query_result: Dict[str, Any] = res_data["reloadRepositoryLocation"]
        query_result_type: str = query_result["__typename"]
        if query_result_type == "WorkspaceLocationEntry":
            location_or_error_type = query_result["locationOrLoadError"][
                "__typename"]
            if location_or_error_type == "RepositoryLocation":
                return ReloadRepositoryLocationInfo(
                    status=ReloadRepositoryLocationStatus.SUCCESS)
            else:
                return ReloadRepositoryLocationInfo(
                    status=ReloadRepositoryLocationStatus.FAILURE,
                    failure_type="PythonError",
                    message=query_result["locationOrLoadError"]["message"],
                )
        else:
            # query_result_type is either ReloadNotSupported or RepositoryLocationNotFound
            return ReloadRepositoryLocationInfo(
                status=ReloadRepositoryLocationStatus.FAILURE,
                failure_type=query_result_type,
                message=query_result["message"],
            )

    def shutdown_repository_location(
            self,
            repository_location_name: str) -> ShutdownRepositoryLocationInfo:
        """Shuts down the server that is serving metadata for the provided repository location.

        This is primarily useful when you want the server to be restarted by the compute environment
        in which it is running (for example, in Kubernetes, the pod in which the server is running
        will automatically restart when the server is shut down, and the repository metadata will
        be reloaded)

        Args:
            repository_location_name (str): The name of the repository location

        Returns:
            ShutdownRepositoryLocationInfo: Object with information about the result of the shutdown request

        Raises:
            Exception: if the server returns a result type this client does not recognize.
        """
        check.str_param(repository_location_name, "repository_location_name")

        res_data: Dict[str, Dict[str, Any]] = self._execute(
            SHUTDOWN_REPOSITORY_LOCATION_MUTATION,
            {"repositoryLocationName": repository_location_name},
        )

        query_result: Dict[str, Any] = res_data["shutdownRepositoryLocation"]
        query_result_type: str = query_result["__typename"]
        if query_result_type == "ShutdownRepositoryLocationSuccess":
            return ShutdownRepositoryLocationInfo(
                status=ShutdownRepositoryLocationStatus.SUCCESS)
        elif (query_result_type == "RepositoryLocationNotFound"
              or query_result_type == "PythonError"):
            return ShutdownRepositoryLocationInfo(
                status=ShutdownRepositoryLocationStatus.FAILURE,
                message=query_result["message"],
            )
        else:
            raise Exception(
                f"Unexpected query result type {query_result_type}")
Пример #15
0
def gql_query(schema, query_str, **params):
    """Execute ``query_str`` against a local ``schema`` and return the result.

    Every keyword argument is passed to the query as a GraphQL variable.
    """
    local_client = Client(schema=schema)
    return local_client.execute(gql(query_str), variable_values=params)
Пример #16
0
def execute(client: Client, query: Document, **kwargs) -> dict:
    """Run ``query`` through ``client`` and return the response.

    Extra keyword arguments are forwarded unchanged to ``client.execute``.
    When DEBUG logging is enabled, the response is also logged as indented JSON.
    """
    import logging  # local import: only needed for the level check below

    response = client.execute(query, **kwargs)
    # Serialize only when DEBUG is actually enabled: dumping a large response is
    # wasted work otherwise, and json.dumps raises TypeError on values that are
    # not JSON-serializable, which would fail a successful query just because
    # of logging. default=str keeps the debug dump itself from crashing.
    if logger.isEnabledFor(logging.DEBUG):
        logger.debug("Response:\n%s", json.dumps(response, indent=2, default=str))
    return response
Пример #17
0
class APIClient(object):
    """Client for a bearer-token-authenticated GraphQL API at ``<baseUrl>/v3/graphql``.

    Queries and mutations go through a gql ``Client``; transcript uploads use a
    raw multipart ``requests.post``. Methods are best-effort: failures are
    printed and reported through the return value instead of being raised.
    """

    def __new__(cls, baseUrl, token):
        # Reject a missing token before __init__ runs, so no transport is built
        # (and no schema fetch is attempted) without credentials.
        if token is None:
            raise ValueError('APIClient requires a non-None token')
        return super().__new__(cls)

    def __init__(self, baseUrl, token):
        self.base_url = baseUrl + '/v3/graphql'
        self.headers = {
            'Authorization': 'Bearer %s' % token,
        }
        transport = RequestsHTTPTransport(self.base_url, headers=self.headers,
                                          use_json=True, timeout=DEFAULT_REQUEST_TIMEOUT)
        self.client = Client(transport=transport, fetch_schema_from_transport=True)

    def get_recording(self, recording_id, asset_type):
        """Return ``{'assets': [...]}`` with the asset records of ``recording_id``.

        Returns None (after printing the error) if the query fails.
        """
        # NOTE(review): recording_id / asset_type are interpolated directly into
        # the query text; if they can come from untrusted input, switch to GraphQL
        # variables once the argument types are confirmed against the schema.
        query = gql('''
            query{
              temporalDataObject(id:"%s"){
                assets(assetType:%s) {
                  records  {
                    id
                    assetType
                    contentType
                    createdDateTime
                    signedUri
                  }
                }
              }
            }
            ''' % (recording_id, asset_type))
        try:
            response = self.client.execute(query)
            return {
                'assets': response['temporalDataObject']['assets']['records']
            }
        except Exception as e:
            print('Failed to find {} for recording_id {} due to: {}'.format(asset_type, recording_id, e))
            return None

    def save_transcript(self, recording_id, assetType, contentType, engineId, language, transcript):
        """Upload ``transcript`` as a new asset of ``recording_id``.

        For assetType ``'text'`` the transcript is uploaded as-is (must be a str);
        otherwise it is rendered to XML with ``xmltodict.unparse``.
        Returns True if the upload request returned HTTP 200, else False.
        """
        if assetType == 'text':
            filename = 'translation.txt'
            file_content = transcript
        else:
            filename = 'translation.ttml'
            file_content = xmltodict.unparse(transcript, pretty=True)

        # NOTE(review): values are interpolated inside quoted GraphQL strings;
        # a value containing `"` would break the mutation.
        query = '''
            mutation {
              createAsset(
                input: {
                    containerId: "%s",
                    assetType: "%s",
                    contentType: "%s",
                    name: "%s",
                    jsondata:{
                        source: "%s",
                        language: "%s"
                    }
                }) {
                id
                signedUri
              }
            }
            ''' % (recording_id, assetType, contentType, filename, engineId, language)

        data = {
            'query': query,
            'filename': filename
        }

        files = {
            'file': (filename, file_content.encode('utf-8'))
        }

        try:
            # Same timeout as the gql transport, so an unresponsive server
            # cannot hang the upload indefinitely.
            response = requests.post(self.base_url, data=data, files=files,
                                     headers=self.headers, timeout=DEFAULT_REQUEST_TIMEOUT)
            return response.status_code == HTTPStatus.OK
        except Exception as e:
            print('Failed to create asset for recording: {} due to: {}'.format(recording_id, e))
            return False

    def update_task(self, job_id, task_id, status, output=None):
        """Set ``status`` (and optional ``output``) on a task of ``job_id``.

        Returns True on success, False if ``status`` is not in VALID_TASK_STATUS
        or the mutation fails.
        """
        if status not in VALID_TASK_STATUS:
            return False

        if output is None:
            output = {}

        # NOTE(review): json.dumps emits quoted object keys, which is not GraphQL
        # object-literal syntax; a non-empty ``output`` may be rejected by the
        # parser. Consider passing it as a variable — verify against the schema.
        query = gql('''
            mutation {
              updateTask(input: {
                    id: "%s",
                    jobId: "%s",
                    status: %s,
                    output: %s
                }) {
                id
                status
              }
            }
        ''' % (task_id, job_id, status, json.dumps(output)))
        try:
            self.client.execute(query)
        except Exception as e:
            print('Failed to update task {} status to {} due to: {}'.format(task_id, status, e))
            return False
        # Previously fell through and returned None (falsy) on success, which
        # was indistinguishable from failure for truthiness checks.
        return True
Пример #18
0
# Look up one researcher (by ORCID) in the DataCite GraphQL API, print their
# name, and tabulate the ids/related identifiers of up to 50 publications.
_transport = RequestsHTTPTransport(url='https://api.datacite.org/graphql', use_json=True)
client = Client(transport=_transport, fetch_schema_from_transport=True)

query = gql("""{
   researcher(id: "https://orcid.org/0000-0003-1419-2405") {
    id
    name
    publications(first: 50) {
      totalCount
      nodes {
        id
        relatedIdentifiers {
          relatedIdentifier
        }
      }
    }
  }
}""")

data = client.execute(query)
researcher = data["researcher"]
print(researcher["name"])

publication_nodes = researcher["publications"]["nodes"]
df = pd.DataFrame(publication_nodes)
print(df)
#print(data.researcher.name)
query FindRecentlyClosedIssues($repository_owner: String!,
                               $repository_name: String!,
                               $issues_last: Int = 20,
                               $issues_states: [IssueState!] = CLOSED,
                               $labels_first: Int = 5) {
  repository(owner: $repository_owner, name: $repository_name) {
    issues(last: $issues_last, states: $issues_states) {
      edges {
        node {
          title
          url
          labels(first: $labels_first) {
            edges {
              node {
                name
              }
            }
          }
        }
      }
    }
  }
}
''')

pprint(
    client.execute(query, {
        'repository_owner': 'octocat',
        'repository_name': 'Hello-World'
    }))
Пример #20
0
class SynthweetixBot:
    """Twitter bot reporting large Synthetix trades, cross-asset swaps and shorts.

    Every :meth:`execute` run queries three subgraphs (Synthetix exchanges,
    Curve swaps, Synthetix shorts) for events at or after
    ``timestamp_last_fetch``. It prices the events in USD and tweets each one
    whose value reaches the configured threshold.
    """

    # Seconds to wait for an Etherscan response before giving up; without a
    # timeout a stalled connection would block the bot indefinitely.
    _HTTP_TIMEOUT = 30

    def __init__(self,
                 key,
                 secret,
                 access_token,
                 access_secret,
                 etherscan_api_key,
                 trade_value_threshold=250000,
                 short_position_value_threshold=100000,
                 eye_catcher_threshold=1000000,
                 debug=False):
        """Create the Twitter API handle and one gql client per subgraph.

        :param key, secret: Twitter consumer key/secret for ``OAuthHandler``.
        :param access_token, access_secret: Twitter access token pair.
        :param etherscan_api_key: key used for the Etherscan REST calls.
        :param trade_value_threshold: minimum USD value of a trade or
            cross-asset swap for it to be tweeted.
        :param short_position_value_threshold: minimum borrowed USD value of a
            short position for it to be tweeted.
        :param eye_catcher_threshold: USD value from which a tweet is prefixed
            with one of ``EYE_CATCHERS``.
        :param debug: when true, tweets are only logged, never posted.
        """
        auth = OAuthHandler(key, secret)
        auth.set_access_token(access_token, access_secret)
        self.api = API(auth)

        self.etherscan_api_key = etherscan_api_key

        # Trades
        self.gql_client_synthetix_exchanges = self._make_gql_client(
            EXCHANGE_SUBGRAPH_API_ENDPOINT)
        # Cross-asset Swaps
        self.gql_client_curve = self._make_gql_client(CURVE_SUBGRAPH_API_ENDPOINT)
        # Short positions
        self.gql_client_synthetix_shorts = self._make_gql_client(
            SHORTS_SUBGRAPH_API_ENDPOINT)

        # CoinGecko
        self.cg = CoinGeckoAPI()

        self.trade_value_threshold = trade_value_threshold
        self.short_position_value_threshold = short_position_value_threshold
        self.eye_catcher_threshold = eye_catcher_threshold

        self.timestamp_last_fetch = int(time.time())
        self.debug = debug

    @staticmethod
    def _make_gql_client(url):
        """Build a gql client for ``url`` (TLS verified, 3 retries, schema fetched)."""
        transport = RequestsHTTPTransport(
            url=url,
            verify=True,
            retries=3,
        )
        return Client(transport=transport, fetch_schema_from_transport=True)

    @staticmethod
    def _tx_hashes(txs):
        """Return the set of lower-cased ``hash`` values of the transaction dicts in ``txs``.

        Entries that are not dicts or have no ``hash`` are ignored.
        """
        return {
            tx['hash'].lower()
            for tx in txs
            if isinstance(tx, dict) and tx.get('hash')
        }

    def send_tweet(self, type_: ExchangeType, message):
        """Log ``message`` with a header for ``type_`` and, unless in debug mode, tweet it.

        Twitter errors (``TweepError``) are logged as warnings, not raised.
        """
        message = f'\U0001F4B0 #Synthetix High Roller {type_.value} \U0001F4B0\n' \
                  f'{message}'

        logging.info(message)
        if not self.debug:
            try:
                self.api.update_status(message)
            except TweepError as e:
                logging.warning(e)

    def fetch_trades(self):
        """Return synth exchanges at or after ``timestamp_last_fetch`` ([] if none)."""
        query = gql(f"""
            query getSynthExchanges {{
                synthExchanges (
                        where: {{ 
                            timestamp_gte: {self.timestamp_last_fetch}
                        }}, orderBy: timestamp, orderDirection: asc) 
                {{
                    id
                    account
                    from
                    fromCurrencyKey
                    fromAmount
                    fromAmountInUSD
                    toCurrencyKey
                    toAmount
                    toAmountInUSD
                    feesInUSD
                    timestamp
                }}
            }}
            """)
        result = self.gql_client_synthetix_exchanges.execute(query)
        # A null field would otherwise break the caller's iteration.
        return result.get('synthExchanges') or []

    def fetch_vyper_transactions(self):
        """Return Etherscan transactions sent to the Vyper contract since the last fetch.

        Makes two Etherscan calls. The first resolves ``timestamp_last_fetch``
        to the closest following block. The second lists the contract's normal
        transactions from that block on. Each returned entry is a transaction
        object (dict) as produced by Etherscan's ``txlist`` action.

        If a response cannot be parsed, the problem is logged and ``[]`` is
        returned. ``requests`` errors propagate.
        """
        try:
            with requests.get(
                    f'https://api.etherscan.io/api?module=block&action=getblocknobytime'
                    f'&timestamp={self.timestamp_last_fetch}&closest=after&'
                    f'apikey={self.etherscan_api_key}',
                    timeout=self._HTTP_TIMEOUT) as r:
                start_block = int(json.loads(r.text)['result'])

            with requests.get(
                    f'https://api.etherscan.io/api?module=account&action=txlist&'
                    f'address={ETHERSCAN_VYPER_CONTRACT}&startblock={start_block}&'
                    f'sort=asc&apikey={self.etherscan_api_key}',
                    timeout=self._HTTP_TIMEOUT) as r:
                txs = json.loads(r.text)['result']
        except (ValueError, KeyError, TypeError) as e:
            logging.warning(e)
            return []

        # Etherscan may report a problem as a message string in ``result``
        # instead of a list of transactions.
        if not isinstance(txs, list):
            logging.warning(txs)
            return []
        return txs

    def fetch_curve_swaps(self):
        """Return Curve swaps at or after ``timestamp_last_fetch`` ([] if none)."""
        query = gql(f"""
            query swaps {{
                swaps (
                        where: {{ 
                            timestamp_gte: {self.timestamp_last_fetch}
                        }}, orderBy: timestamp, orderDirection: asc) 
                {{
                    fromToken {{
                      symbol
                    }}
                    fromTokenAmountDecimal
                    toToken {{
                      symbol
                    }}
                    toTokenAmountDecimal
                    underlyingPrice
                    timestamp
                    transaction {{
                      hash
                    }}
                }}
            }}
            """)
        result = self.gql_client_curve.execute(query)
        return result.get('swaps') or []

    def fetch_shorts(self):
        """Return open shorts created at or after ``timestamp_last_fetch`` ([] if none)."""
        query = gql(f"""
            query shorts {{
                shorts (
                        where: {{ 
                            isOpen: true,
                            createdAt_gte: {self.timestamp_last_fetch}
                        }}, orderBy: createdAt, orderDirection: asc) 
                {{
                    id
                    txHash
                    account
                    collateralLocked
                    collateralLockedAmount
                    synthBorrowed
                    synthBorrowedAmount
                    createdAt
                }}
            }}
            """)
        result = self.gql_client_synthetix_shorts.execute(query)
        return result.get('shorts') or []

    def create_trades_tweets(self, trades):
        """Tweet every trade in ``trades`` (raw subgraph records).

        Amounts are 18-decimal fixed-point values and are divided by 1e18.
        Currency keys are hex ``bytes32`` strings; the first 4 bytes are
        decoded as the symbol.
        """
        for trade in trades:
            account = trade.get('account')

            from_amount = float(trade.get('fromAmount')) / 1e18
            from_currency = bytes.fromhex(
                trade.get('fromCurrencyKey')[2:10]).decode('utf-8')
            from_amount_usd = float(trade.get('fromAmountInUSD')) / 1e18

            to_amount = float(trade.get('toAmount')) / 1e18
            to_currency = bytes.fromhex(
                trade.get('toCurrencyKey')[2:10]).decode('utf-8')
            to_amount_usd = float(trade.get('toAmountInUSD')) / 1e18

            fees_usd = float(trade.get('feesInUSD')) / 1e18

            message = []
            if to_amount_usd >= self.eye_catcher_threshold or from_amount_usd >= self.eye_catcher_threshold:
                r = randint(0, 2)
                message.append(EYE_CATCHERS[r])
            message.extend([
                'FROM {:,.2f} {} (${:,.2f})'.format(from_amount, from_currency,
                                                    from_amount_usd),
                'TO {:,.2f} {} (${:,.2f})'.format(to_amount, to_currency,
                                                  to_amount_usd),
                'FEES = {:,.2f}'.format(fees_usd),
                'https://etherscan.io/address/{}'.format(account),
            ])

            self.send_tweet(ExchangeType.TRADES, '\n'.join(message))

    def create_swaps_tweets(self, swaps):
        """Tweet every swap in ``swaps``.

        Each swap must already carry ``fromTokenAmountUSD`` and
        ``toTokenAmountUSD``, which :meth:`execute` adds.
        """
        for swap in swaps:
            transaction = swap.get('transaction').get('hash')

            from_token = swap.get('fromToken').get('symbol')
            from_token_amount = float(swap.get('fromTokenAmountDecimal'))
            from_token_amount_usd = swap.get('fromTokenAmountUSD')

            to_token = swap.get('toToken').get('symbol')
            to_token_amount = float(swap.get('toTokenAmountDecimal'))
            to_token_amount_usd = swap.get('toTokenAmountUSD')

            message = []
            if from_token_amount_usd >= self.eye_catcher_threshold or to_token_amount_usd >= self.eye_catcher_threshold:
                r = randint(0, 2)
                message.append(EYE_CATCHERS[r])
            message.extend([
                'FROM {:,.2f} {} (${:,.2f})'.format(from_token_amount,
                                                    from_token,
                                                    from_token_amount_usd),
                'TO {:,.2f} {} (${:,.2f})'.format(to_token_amount, to_token,
                                                  to_token_amount_usd),
                'https://etherscan.io/tx/{}'.format(transaction),
            ])

            self.send_tweet(ExchangeType.SWAPS, '\n'.join(message))

    def create_shorts_tweets(self, shorts):
        """Tweet every short in ``shorts``.

        The records must already be converted by :meth:`execute`: decoded
        symbols, decimal amounts, and added ``*AmountUSD`` fields.
        """
        for short in shorts:
            tx_hash = short.get('txHash')

            synth_borrowed = short.get('synthBorrowed')
            synth_borrowed_amount = short.get('synthBorrowedAmount')
            synth_borrowed_amount_usd = short.get('synthBorrowedAmountUSD')

            collateral_locked = short.get('collateralLocked')
            collateral_locked_amount = short.get('collateralLockedAmount')
            collateral_locked_amount_usd = short.get(
                'collateralLockedAmountUSD')

            message = []
            if synth_borrowed_amount_usd >= self.eye_catcher_threshold:
                r = randint(0, 2)
                message.append(EYE_CATCHERS[r])
            message.extend([
                'SYNTH BORROWED {:,.2f} {} (${:,.2f})'.format(
                    synth_borrowed_amount, synth_borrowed,
                    synth_borrowed_amount_usd),
                'COLLATERAL LOCKED {:,.2f} {} (${:,.2f})'.format(
                    collateral_locked_amount, collateral_locked,
                    collateral_locked_amount_usd),
                'https://etherscan.io/tx/{}'.format(tx_hash),
            ])

            self.send_tweet(ExchangeType.SHORTS, '\n'.join(message))

    def execute(self):
        """Run one fetch / price / tweet cycle for trades, swaps and shorts.

        On success ``timestamp_last_fetch`` moves to the moment this cycle
        started. On a ``requests`` error the error is logged and the timestamp
        is left unchanged, so the next run queries the same window again.
        Other exceptions propagate.
        """
        start = datetime.now()
        logging.info('Running SynthweetixBot')
        # Taken before any query so that events occurring while this cycle
        # runs fall into the next cycle's window instead of being skipped.
        fetch_started = int(time.time())

        try:
            # USD price per token symbol, shared by the swap and short sections.
            prices = {}

            # Trades
            logging.info(
                f'Fetching trades from TheGraph at {self.gql_client_synthetix_exchanges.transport}'
            )
            whales = [
                trade for trade in self.fetch_trades()
                if float(trade.get('toAmountInUSD')) /
                1e18 >= self.trade_value_threshold
            ]

            logging.info('Sending tweets for trades')
            self.create_trades_tweets(whales)

            # Cross-asset Swaps
            logging.info(
                f'Fetching cross-asset swaps from TheGraph at {self.gql_client_curve.transport}'
            )
            swaps = self.fetch_curve_swaps()
            whales = []

            # Etherscan returns transaction objects, not bare hashes: compare by
            # lower-cased hash.
            vyper_tx_hashes = self._tx_hashes(self.fetch_vyper_transactions())

            for swap in swaps:
                swap_hash = (swap.get('transaction').get('hash') or '').lower()
                if swap_hash not in vyper_tx_hashes:
                    # Not routed through the Vyper contract: not a cross-asset swap.
                    continue

                from_token = swap.get('fromToken').get('symbol')
                if from_token not in prices:
                    prices[from_token] = cryptocompare.get_price(from_token, currency='usd') \
                        .get(from_token.upper()).get('USD')

                to_token = swap.get('toToken').get('symbol')
                if to_token not in prices:
                    prices[to_token] = cryptocompare.get_price(to_token, currency='usd') \
                        .get(to_token.upper()).get('USD')

                from_token_amount_usd = float(
                    swap.get('fromTokenAmountDecimal')) * prices.get(from_token)
                to_token_amount_usd = float(
                    swap.get('toTokenAmountDecimal')) * prices.get(to_token)
                if to_token_amount_usd >= self.trade_value_threshold:
                    swap['fromTokenAmountUSD'] = from_token_amount_usd
                    swap['toTokenAmountUSD'] = to_token_amount_usd
                    whales.append(swap)

            logging.info('Sending tweets for cross-asset swaps')
            self.create_swaps_tweets(whales)

            # Short positions
            logging.info(
                f'Fetching short positions from TheGraph at {self.gql_client_synthetix_shorts.transport}'
            )
            shorts = self.fetch_shorts()
            whales = []

            for short in shorts:
                # Currency keys are hex bytes32 strings; the first 4 bytes hold the symbol.
                synth_token = bytes.fromhex(
                    short.get('synthBorrowed')[2:10]).decode('utf-8')
                if synth_token not in prices:
                    prices[synth_token] = self.cg.get_price(ids=synth_token, vs_currencies='usd') \
                        .get(synth_token.lower()).get('usd')

                collateral_token = bytes.fromhex(
                    short.get('collateralLocked')[2:10]).decode('utf-8')
                if collateral_token not in prices:
                    prices[collateral_token] = cryptocompare.get_price(collateral_token, currency='usd') \
                        .get(collateral_token.upper()).get('USD')

                synth_borrowed_amount = float(
                    short.get('synthBorrowedAmount')) / 1e18
                synth_borrowed_amount_usd = synth_borrowed_amount * prices.get(
                    synth_token)

                collateral_locked_amount = float(
                    short.get('collateralLockedAmount')) / 1e18
                collateral_locked_amount_usd = collateral_locked_amount * prices.get(
                    collateral_token)

                if synth_borrowed_amount_usd >= self.short_position_value_threshold:
                    short['synthBorrowed'] = synth_token
                    short['synthBorrowedAmount'] = synth_borrowed_amount
                    short['synthBorrowedAmountUSD'] = synth_borrowed_amount_usd

                    short['collateralLocked'] = collateral_token
                    short['collateralLockedAmount'] = collateral_locked_amount
                    short['collateralLockedAmountUSD'] = collateral_locked_amount_usd
                    whales.append(short)

            logging.info('Sending tweets for short positions')
            self.create_shorts_tweets(whales)

            self.timestamp_last_fetch = fetch_started
        except requests.RequestException as e:
            logging.error(e)

        end = datetime.now()
        logging.info(f'Executed SynthweetixBot in {end - start}s')
Пример #21
0
class AlectioClient:
    """Client for the Alectio backend: GraphQL queries/mutations plus CLI helpers.

    The API key comes from ``api_key`` or, failing that, from
    ``get_credentials()``. It is mirrored into both ``os.environ`` and
    ``environ`` under ``ALECTIO_API_KEY``.
    """

    # Placeholder user id, used only when none could be resolved during
    # construction (from the API or from stored credentials).
    _FALLBACK_USER_ID = "82b4fb909f1f11ea9d300242ac110002"

    def __init__(self, environ=os.environ, api_key=None):
        """Resolve the API key and user id, then build the GraphQL clients.

        :param environ: mapping the API key is read from and written to.
        :param api_key: explicit API key; when omitted, stored credentials are used.
        :raises APIKeyNotFound: when no key is given, none is in ``environ``,
            and no stored credentials exist.
        """
        self._environ = environ

        if (not api_key and 'ALECTIO_API_KEY' not in self._environ
                and not self._credentials_available()):
            raise APIKeyNotFound

        self._client_token = None
        self._user_id = None
        # cli user settings
        self._settings = {
            'git_remote': "origin",
            'env_base_url': "http://localhost:5005",
            'prod_base_url': "https://api.alectio.com"
        }

        if api_key:
            os.environ['ALECTIO_API_KEY'] = api_key
            self._environ['ALECTIO_API_KEY'] = api_key
            self._api_key = api_key
            self._user_id = self._fetch_user_id(api_key)
        else:
            # The credentials include the API key, so they are never printed.
            self._api_key_data = get_credentials()
            os.environ['ALECTIO_API_KEY'] = self._api_key_data['api_key']
            self._environ['ALECTIO_API_KEY'] = self._api_key_data['api_key']
            self._api_key = self._api_key_data['api_key']
            self._user_id = self._api_key_data['user_id']

        self._endpoint = self._settings['env_base_url'] + "/graphql"

        # graphql client
        self._client = Client(
            transport=RequestsHTTPTransport(
                url=self._endpoint,
                verify=False,
                retries=3,
                headers={'APIKEY': self._api_key}
            ),
            fetch_schema_from_transport=True,
        )

        # client to upload files, images, etc.
        # uses https://pypi.org/project/aiogqlc/
        self._upload_client = GraphQLClient(self._endpoint)  # change for dev
        self._oauth_server = 'https://auth.alectio.com/'
        # Keep the id resolved above; fall back to the placeholder only when
        # nothing could be resolved (previously the placeholder always won).
        if self._user_id is None:
            self._user_id = self._FALLBACK_USER_ID

    @staticmethod
    def _credentials_available():
        """Return whether stored credentials exist.

        ``credentials_exists`` is defined elsewhere. Referencing a function
        without calling it is always truthy and would disable the key check,
        so it is called when callable.
        NOTE(review): confirm whether it is a function or a flag.
        """
        if callable(credentials_exists):
            return bool(credentials_exists())
        return bool(credentials_exists)

    def _fetch_user_id(self, api_key):
        """Ask ``<env_base_url>/api/me`` for the id of the user owning ``api_key``.

        Prints a message and returns None when the key is rejected, the server
        answers unexpectedly, or the request/response fails.
        """
        try:
            headers = {"Authorization": "Bearer " + api_key}
            r = requests.get(self._settings['env_base_url'] + '/api/me', headers=headers)
            if r.status_code == 401:
                print("Expired or Wrong API Key. Please run alectio login YOU_API_KEY")
            elif r.status_code == 200:
                # ``requests.Response`` exposes ``json()``; ``get_json()`` does not exist.
                return r.json()['id']
            else:
                print('Something Went wrong. Check you API Key')
        except (requests.RequestException, ValueError, KeyError, TypeError):
            print('\n Alectio Error: Something Went wrong.\n')
        return None

    def set_user_id(self, token):
        """Set ``self._user_id`` from the OAuth server's ``api/me`` for access ``token``.

        A 401 response prints a refresh hint; other non-200 responses are ignored.
        """
        # Use the token passed in: ``self._client_token`` is None when called
        # via ``init(token=...)``.
        headers = {"Authorization": "Bearer " + token}
        requests_data = requests.get(
            url=self._oauth_server + 'api/me', headers=headers)
        if requests_data.status_code == 401:
            print("Client Token Expired. Please run alectio-kms --refresh")
        elif requests_data.status_code == 200:
            requests_data = requests_data.json()
            self._user_id = requests_data['id']

    def init(self, token=None, file_path='/opt/alectio/client_token.json'):
        """Resolve the user id from ``token`` or from the client-token JSON at ``file_path``.

        The file is read only when no client token has been loaded yet.
        """
        if token:
            self.set_user_id(token)
        elif file_path and self._client_token is None:
            with open(file_path, 'r') as f:
                self._client_token = json.load(f)
            self.set_user_id(self._client_token['access_token'])

    def get_single(self, resource, query_string, params):
        """
        return a single object for the requested resource.
        :params: resource - name of the resource to obtain i.e experiment, project, model, job.
        :params: query_string - graphql string to invoke.
        :params: params - variables required to invoke query string in the client.
        The first element of ``result[resource]`` is wrapped in the module-level
        class named ``resource.title()``.
        """
        query = gql(query_string)
        class_name = lambda class_name: getattr(sys.modules[__name__], class_name)
        singular = self._client.execute(query, params)[resource][0]
        class_to_init = class_name(resource.title())
        hash_key = extract_id(singular['pk'])
        if resource == "project":
            hash_key = extract_id(singular['sk'])

        if not resource == "job":
            return class_to_init(self._client, singular, self._user_id, hash_key)

        # job object class is slighty different in design
        return class_to_init(self._upload_client, singular, self._user_id, hash_key)

    def mutate_single(self, query_string, params):
        """
        run a GraphQL mutation and return the raw result (also printed).
        :params: query_string - graphql string to invoke.
        :params: params - variables required to invoke query string in the client.
        """
        query = gql(query_string)
        singular = self._client.execute(query, params)
        print(singular)
        return singular

    def get_collection(self, resource, query_string, params):
        """
        return a collection of objects for the requested resource.
        :params: resource - plural resource name i.e experiments, projects, models, jobs.
        :params: query_string - graphql string to invoke.
        :params: params - variables required to invoke query string in the client.
        Each item is wrapped in the module-level class named after the resource
        with its last character dropped (``"projects"`` -> ``Project``).
        """
        query = gql(query_string)
        singular_resource = lambda resource_name: str(resource_name.title()[0:-1])  # format resource name to match one of the existing classes
        class_name = lambda class_name: getattr(sys.modules[__name__], class_name)  # convert string to class name
        collection = self._client.execute(query, params)[resource]
        class_to_init = class_name(singular_resource(resource))

        if not resource == "jobs":
            return [class_to_init(self._client, item, self._user_id, extract_id(item['sk'])) for item in collection]
        # jobs resource
        return [class_to_init(self._upload_client, item, self._user_id, extract_id(item['pk'])) for item in collection]

    def projects(self):
        """
        retrieve the current user's projects
        """
        params = {
            "id": str(self._user_id),
        }
        return self.get_collection("projects", PROJECTS_QUERY_FRAGMENT, params)

    def experiments(self, project_id):
        """
        retrieve experiments that belong to a project
        :params: project_id - a uuid
        """
        params = {
            "id": str(project_id),
        }
        return self.get_collection("experiments", EXPERIMENTS_QUERY_FRAGMENT, params)

    def experiment(self, experiment_id):
        """
        retrieve a single experiment
        :params: experiment_id - a uuid
        """
        params = {
            "id": str(experiment_id),
        }
        return self.get_single("experiment", EXPERIMENT_QUERY_FRAGMENT, params)

    # grab user id + project id
    def project(self, project_id):
        """
        retrieve a single user project
        :params: project_id - a uuid
        """
        params = {
             "userId": str(self._user_id),
             "projectId": str(project_id)
        }
        return self.get_single("project", PROJECT_QUERY_FRAGMENT, params)

    def models(self, organization_id):
        """
        retrieve models associated with a user / organization.
        :params: organization_id - a uuid
        """
        params = {
            "id": str(organization_id),
        }
        return self.get_collection("models", MODELS_QUERY_FRAGMENT, params)

    def model(self, model_id):
        """
        retrieve a single user model
        :params: model_id - a uuid
        """
        params = {
            "id": str(model_id),
        }
        return self.get_single("model", MODEL_QUERY_FRAGMENT, params)

    def job(self, job_id, project_id):
        """
        returns a single labeling job
        :params: job_id - job uuid
        :params: project_id - uuid of the project the job belongs to
        """
        params = {
            "id": str(job_id),
            "projectId": str(project_id)
        }
        return self.get_single("job", JOB_QUERY_FRAGMENT, params)

    def create_project(self, file):
        """
        create a user project from the ``Project`` section of the YAML ``file``
        and return the created project object.
        Raises OSError when the file cannot be opened.
        """
        with open(file, 'r') as yaml_in:
            yaml_object = yaml.safe_load(yaml_in)  # yaml_object will be a list or a dict
        project_dict = yaml_object['Project']
        project_dict['userId'] = self._user_id
        now = datetime.now()
        project_dict['date'] = now.strftime("%Y/%m/%d")  # We probably should not allow cli to decide DATE
        print(project_dict)
        response = self.mutate_single(PROJECT_CREATE_FRAGMENT, project_dict)
        new_project_created = response["createProject"]["ok"]
        return self.project(new_project_created)

    def upload_class_labels(self, class_labels_file, project_id):
        """
        Precondition: create_project has been called
        Send the class labels from the JSON file ``class_labels_file`` to the
        project's storage (project_id/meta.json) and return the mutation result.
        """
        with open(class_labels_file) as f:
            # serialize json object to a string to be sent to server
            data = json.dumps(json.load(f))

        params = {
            "userId": self._user_id,
            "projectId": project_id,
            "classLabels": data
        }

        return self.mutate_single(UPLOAD_CLASS_LABELS_MUTATION, params)

    def create_experiment(self, file):
        """
        create a user experiment from the ``Experiment`` section of the YAML
        ``file`` and return the created experiment object, or None when
        ``Experiment.projectId`` is not set.
        """
        env = EnvYAML(file)
        project_id = env['Experiment.projectId']
        if project_id is None:
            print("no project id was set")
            return

        with open(file, 'r') as yaml_in:
            yaml_object = yaml.safe_load(yaml_in)  # yaml_object will be a list or a dict
        experiment_dict = yaml_object['Experiment']
        experiment_dict['experimentId'] = "".join(str(uuid.uuid1()).split("-"))
        experiment_dict['userId'] = self._user_id
        now = datetime.now()
        experiment_dict['date'] = now.strftime("%m-%d-%Y")
        # project id from the env variable
        experiment_dict['projectId'] = project_id
        response = self.mutate_single(EXPERIMENT_CREATE_FRAGMENT, experiment_dict)
        new_experiment_created = response["createExperiment"]["ok"]

        return self.experiment(new_experiment_created)

    def create_model(self, model_path):
        """
        upload model checksum and verify there are enough models to check
        NOTE(review): currently a stub; ``model_path`` is ignored.
        :returns: model object
        """
        return Model("", "", "", "")
Пример #22
0
class NewRelicGQL(object):
    """Wrapper around New Relic's GraphQL API for a single account.

    Provides helpers to read the account license key and to manage
    cloud-integration linked accounts and integrations.
    """

    def __init__(self, account_id, api_key, region="us"):
        """Validate the arguments and build the GraphQL client.

        :param account_id: New Relic account id; anything ``int()`` accepts.
        :param api_key: key sent in the ``api-key`` header.
        :param region: ``"us"`` or ``"eu"``; selects the API host.
        :raises ValueError: for a non-integer account id or an unknown region.
        """
        try:
            self.account_id = int(account_id)
        except (TypeError, ValueError):
            # int(None) raises TypeError; report it like any other bad id.
            raise ValueError("Account ID must be an integer")

        self.api_key = api_key

        if region == "us":
            self.url = "https://api.newrelic.com/graphql"
        elif region == "eu":
            self.url = "https://api.eu.newrelic.com/graphql"
        else:
            raise ValueError("Region must be one of 'us' or 'eu'")

        transport = RequestsHTTPTransport(url=self.url, use_json=True)
        transport.headers = {"api-key": self.api_key}

        # If building the client with schema fetching fails, fall back to a
        # client without a schema.
        try:
            self.client = Client(transport=transport,
                                 fetch_schema_from_transport=True)
        except Exception:
            self.client = Client(transport=transport,
                                 fetch_schema_from_transport=False)

    def query(self, query, timeout=None, **variable_values):
        """Execute the GraphQL document ``query`` with keyword variables; return the data dict."""
        return self.client.execute(gql(query),
                                   timeout=timeout,
                                   variable_values=variable_values or None)

    @staticmethod
    def _dig(data, *path):
        """Follow ``path`` (dict keys / list indexes) into ``data``.

        Returns None as soon as a step is missing or hits a null/non-container value.
        """
        for step in path:
            try:
                data = data[step]
            except (KeyError, IndexError, TypeError):
                return None
        return data

    @staticmethod
    def _mutation_errors(res, field):
        """Return the ``errors`` list of mutation payload ``res[field]`` ([] if absent).

        gql's ``execute`` returns only the ``data`` part, so the payload-level
        ``errors`` selected in the mutations live under ``res[field]``, never at
        the top level of ``res``.
        """
        return NewRelicGQL._dig(res, field, "errors") or []

    @staticmethod
    def _build_integrations(linked_account_id, provider_slug, service_slug,
                            aws_region):
        """Build the ``CloudIntegrationsInput`` value for one provider service.

        ``iam``, ``s3``, ``route53`` and ``billing`` are sent without
        ``awsRegions``. ``iam`` is instead tagged with ``tagKey="Region"`` and
        ``tagValue=aws_region``.
        """
        config = {"linkedAccountId": linked_account_id}
        if service_slug == "iam":
            config["tagKey"] = "Region"
            config["tagValue"] = aws_region
        elif service_slug not in ("s3", "route53", "billing"):
            config["awsRegions"] = aws_region
        return {provider_slug: {service_slug: [config]}}

    def get_license_key(self):
        """
        Fetch the license key for the NR Account (None if unavailable)
        """
        res = self.query(
            """
            query ($accountId: Int!) {
              requestContext {
                apiKey
              }
              actor {
                account(id: $accountId) {
                  licenseKey
                  id
                  name
                }
              }
            }
            """,
            accountId=self.account_id,
        )
        return self._dig(res, "actor", "account", "licenseKey")

    def get_linked_accounts(self):
        """
        Return a list of linked accounts for the New Relic account ([] if none)
        """
        res = self.query(
            """
            query ($accountId: Int!) {
              actor {
                account(id: $accountId) {
                  cloud {
                    linkedAccounts {
                      id
                      externalId
                      name
                      authLabel
                    }
                  }
                }
              }
            }
            """,
            accountId=self.account_id,
        )
        return self._dig(res, "actor", "account", "cloud", "linkedAccounts") or []

    def link_account(self, role_arn, account_name):
        """
        Create a linked account (cloud integrations account)
        in the New Relic account.
        Returns the new linked account id, or None (printing any reported errors).
        """
        res = self.query(
            """
            mutation ($accountId: Int!, $accounts: CloudLinkCloudAccountsInput!){
              cloudLinkAccount (accountId: $accountId, accounts: $accounts) {
                linkedAccounts {
                  id
                  name
                }
                errors {
                    message
                }
              }
            }
            """,
            accountId=self.account_id,
            accounts={"aws": {
                "arn": role_arn,
                "name": account_name
            }},
        )
        linked_account_id = self._dig(res, "cloudLinkAccount", "linkedAccounts", 0, "id")
        if linked_account_id is None:
            errors = self._mutation_errors(res, "cloudLinkAccount")
            if errors:
                print("Error while linking account with New Relic: {0}".format(
                    errors))
        return linked_account_id

    def unlink_account(self, linked_account_id):
        """
        Unlink a New Relic Cloud integrations account.
        Returns the raw mutation result; reported errors are printed.
        """
        res = self.query(
            """
            mutation ($accountId: Int!, $accounts: [CloudUnlinkAccountsInput!]!) {
              cloudUnlinkAccount(accountId: $accountId, accounts: $accounts) {
                errors {
                  message
                  type
                }
                unlinkedAccounts {
                  id
                  name
                }
              }
            }
            """,
            accountId=self.account_id,
            accounts={"linkedAccountId": linked_account_id},
        )

        errors = self._mutation_errors(res, "cloudUnlinkAccount")
        if errors:
            print("Error while unlinking account with New Relic: {0}".format(
                errors))

        return res

    def enable_integration(self, linked_account_id, provider_slug,
                           service_slug, aws_region):
        """
        Enable monitoring of a Cloud provider's service (integration).
        Returns the first configured integration, or None (printing any reported errors).
        """
        integrations = self._build_integrations(
            linked_account_id, provider_slug, service_slug, aws_region)

        res = self.query(
            """
            mutation ($accountId: Int!, $integrations: CloudIntegrationsInput!) {
              cloudConfigureIntegration (
                accountId: $accountId,
                integrations: $integrations
              ) {
                integrations {
                  id
                  name
                  service {
                    id
                    name
                  }
                }
                errors {
                  linkedAccountId
                  message
                }
              }
            }
            """,
            accountId=self.account_id,
            integrations=integrations,
        )
        integration = self._dig(res, "cloudConfigureIntegration", "integrations", 0)
        if integration is None:
            errors = self._mutation_errors(res, "cloudConfigureIntegration")
            if errors:
                print("Error while enabling integration with New Relic:\n{0}".
                      format(errors))
        return integration
Пример #23
0
}]
# variables = {"id": "web1-4045064"}

# query CategoryPagination($id: ID!, $page: Int!, $pageSize: Int!, $hasItems: Boolean!, $hasEntries: Boolean!, $facetValues: FacetValuesInput) {  category(id: $id) {    id    name    entries(page: $page, pageSize: $pageSize, facetValues: $facetValues) @include(if: $hasEntries) {      pagination {        page        totalPages        totalResults        __typename}      __typename}    items(page: $page, pageSize: $pageSize, facetValues: $facetValues) @include(if: $hasItems) {      pagination {        page        totalPages        totalResults        __typename}      __typename}    __typename}}

from gql import gql, Client
from gql.transport.aiohttp import AIOHTTPTransport

transport = AIOHTTPTransport(url="https://www.kramp.com/graphql/checkout-app",
                             headers={"ctx-locale": "pl_PL"})

client = Client(transport=transport, fetch_schema_from_transport=True)
query = gql(query)

for var in vars:
    result = client.execute(query,
                            variable_values=var)  #, variable_values=variables
    print(result)
    print("\n\n\n")
'''
query = """query GetCategoryProducts($categoryId: ID!, $isAuthenticated: Boolean!, $pageSize: Int!, $page: Int!, $facetValues: FacetValuesInput) {  category(id: $categoryId) {    id    name    items(page: $page, pageSize: $pageSize, facetValues: $facetValues) {      pagination {        page        totalPages        totalResults        __typename      }      items {        id        name        description        brand {          id          name          logo {            src            alt            __typename          }          __typename        }        classifications {          code          values {            key            value            __typename          }          __typename        }        image {          src          alt          __typename        }        quantity        roundingQuantity        minimumQuantity        variant {          id          name          __typename        }        hasVolumeDiscount @include(if: $isAuthenticated)        grossPrice @include(if: $isAuthenticated) {          value          currency          __typename        }        __typename      }      __typename    }
  __typename  }}"""

vars = {
    "categoryId": "web2-4045755",
    "isAuthenticated": false,
    "pageSize": 60,
    "page": 1,
    "facetValues": {
        "multi": [],
        "range": [],
        "single": []
Пример #24
0
def _empty_tls_guidance():
    """Return a fresh, empty TLS guidance structure.

    Used as the fallback whenever real guidance cannot be retrieved. A new
    dict is built on every call, so callers may mutate the result freely.
    """
    return {
        "ciphers": {
            "1.2": {"recommended": [], "sufficient": [], "phase_out": []},
            "1.3": {"recommended": [], "sufficient": []},
        },
        "curves": {"recommended": [], "sufficient": [], "phase_out": []},
        "signature_algorithms": {
            "recommended": [],
            "sufficient": [],
            "phase_out": [],
        },
        "extensions": {"1.2": {"recommended": []}, "1.3": {"recommended": []}},
    }


def retrieve_tls_guidance():
    """Fetch the TLS guidance JSON document from GitHub.

    The document is read through the GitHub GraphQL API from the
    ``REPO_OWNER/REPO_NAME`` repository at ``main:GUIDANCE_DIR/GUIDANCE_FILE``
    and parsed as JSON.

    Returns:
        dict: The parsed guidance. Returns an empty guidance structure (every
        list empty) when required configuration is missing or when retrieval
        or parsing fails. Failures are logged rather than raised.
    """
    logging.info("Retrieving TLS guidance...")

    if any(
        i is None
        for i in [REPO_NAME, REPO_OWNER, GUIDANCE_DIR, GUIDANCE_FILE, GITHUB_TOKEN]
    ):
        logging.error(
            "Missing one or more secrets required for TLS guidance retrieval. SSL results may not reflect compliance."
        )
        return _empty_tls_guidance()

    # Repository coordinates are sent as GraphQL variables instead of being
    # formatted into the query text. Quotes or braces in the configured values
    # therefore cannot break or alter the query.
    guidance_query = """
    query TLSGuidance($name: String!, $owner: String!, $expression: String!) {
      repository(name: $name, owner: $owner) {
        id
        object(expression: $expression) {
          ... on Blob {
            text
          }
        }
      }
    }
    """
    variables = {
        "name": REPO_NAME,
        "owner": REPO_OWNER,
        "expression": "main:{}/{}".format(GUIDANCE_DIR, GUIDANCE_FILE),
    }

    try:
        gh_client = Client(
            transport=RequestsHTTPTransport(
                url="https://api.github.com/graphql",
                headers={"Authorization": "bearer " + GITHUB_TOKEN},
            ),
        )
        guidance_result = gh_client.execute(
            gql(guidance_query), variable_values=variables
        )
        # If the path does not resolve, "object" is null. The resulting
        # TypeError is handled below like any other retrieval failure.
        guidance = json.loads(guidance_result["repository"]["object"]["text"])
    except Exception as e:
        # Top-level boundary: callers always get a usable structure back.
        logging.error(
            f"Error occurred while retrieving TLS guidance. SSL results may not reflect compliance. {str(e)} \n\nFull traceback: {traceback.format_exc()}"
        )
        return _empty_tls_guidance()

    logging.info("TLS guidance retrieved.")
    return guidance
Пример #25
0
class Api(object):
    """W&B Internal Api wrapper

    Note:
        Settings are automatically overridden by looking for
        a `wandb/settings` file in the current working directory or it's parent
        directory.  If none can be found, we look in the current users home
        directory.

    Args:
        default_settings(:obj:`dict`, optional): If you aren't using a settings
        file or you wish to override the section to use in the settings file
        Override the settings here.
    """

    HTTP_TIMEOUT = env.get_http_timeout(10)

    def __init__(self,
                 default_settings=None,
                 load_settings=True,
                 retry_timedelta=datetime.timedelta(days=1),
                 environ=os.environ):
        """Set up settings, git metadata and the retrying GraphQL client.

        Args:
            default_settings (dict, optional): Overrides merged on top of the
                built-in defaults (section, run, git_remote, ignore_globs,
                base_url).
            load_settings (bool): Passed through to ``Settings``.
            retry_timedelta (datetime.timedelta): Stored on the instance and
                passed to ``retry.Retry`` (presumably the total retry window
                for ``self.gql`` -- confirm against the retry module).
            environ (mapping): Environment used for all env-var lookups.
                Defaults to the live ``os.environ`` object (not a copy).
        """
        # NOTE: statement order matters below. self.settings(), self.api_key and
        # self.user_agent read self._environ, self.default_settings and
        # self._settings, so those must be assigned before the GitRepo and
        # Client are constructed.
        self._environ = environ
        self.default_settings = {
            'section': "default",
            'run': "latest",
            'git_remote': "origin",
            'ignore_globs': [],
            'base_url': "https://api.wandb.ai"
        }
        self.retry_timedelta = retry_timedelta
        self.default_settings.update(default_settings or {})
        self.retry_uploads = 10
        self._settings = Settings(load_settings=load_settings)
        self.git = GitRepo(remote=self.settings("git_remote"))
        # Mutable settings set by the _file_stream_api
        self.dynamic_settings = {
            'system_sample_seconds': 2,
            'system_samples': 15,
            'heartbeat_seconds': 30,
        }
        # HTTP transport authenticates with basic auth ("api", <api key>); an
        # empty key is sent when none is configured (see reauth()).
        self.client = Client(transport=RequestsHTTPTransport(
            headers={
                'User-Agent': self.user_agent,
                'X-WANDB-USERNAME': env.get_username(env=self._environ)
            },
            use_json=True,
            # this timeout won't apply when the DNS lookup fails. in that case, it will be 60s
            # https://bugs.python.org/issue22889
            timeout=self.HTTP_TIMEOUT,
            auth=("api", self.api_key or ""),
            url='%s/graphql' % self.settings('base_url')))
        # self.gql wraps self.execute in retry.Retry. RetryError and
        # requests.RequestException are retryable, and util.no_retry_auth is
        # the retry predicate (presumably it skips auth failures -- confirm).
        self.gql = retry.Retry(
            self.execute,
            retry_timedelta=retry_timedelta,
            check_retry_fn=util.no_retry_auth,
            retryable_exceptions=(RetryError, requests.RequestException))
        self._current_run_id = None
        self._file_stream_api = None

    def reauth(self):
        """Refresh the transport's basic-auth credentials from ``api_key``.

        Uses an empty key when no API key is currently configured.
        """
        current_key = self.api_key or ""
        self.client.transport.auth = ("api", current_key)

    def execute(self, *args, **kwargs):
        """Call ``self.client.execute``, logging diagnostics on HTTP errors.

        All arguments are forwarded unchanged. On a
        ``requests.exceptions.HTTPError`` the status code, the response body
        and any GraphQL error messages are logged. The original exception is
        then re-raised with its original traceback.
        """
        try:
            return self.client.execute(*args, **kwargs)
        except requests.exceptions.HTTPError as err:
            # Capture the exception before running any other code, so the
            # original error and traceback are what get re-raised.
            exc_info = sys.exc_info()
            res = err.response
            # An HTTPError can be constructed without a response. Logging must
            # not raise AttributeError here and mask the real failure.
            if res is not None:
                logger.error("%s response executing GraphQL.", res.status_code)
                logger.error(res.text)
                self.display_gorilla_error_if_found(res)
            six.reraise(*exc_info)

    def display_gorilla_error_if_found(self, res):
        """Print each GraphQL error message found in a response body.

        Args:
            res: An HTTP response object exposing ``.json()``.

        Bodies that are not JSON, or that do not have the shape
        ``{"errors": [{"message": ...}, ...]}``, are ignored. This runs inside
        error handling, so it must not raise and mask the original failure.
        """
        try:
            data = res.json()
        except ValueError:
            return

        # Only a JSON object can carry a top-level "errors" list. Arrays,
        # strings or numbers would make the lookups below raise TypeError.
        if not isinstance(data, dict):
            return
        errors = data.get('errors')
        if not isinstance(errors, list):
            return

        for err in errors:
            # Skip entries that are not objects or that carry no message.
            if not isinstance(err, dict):
                continue
            message = err.get('message')
            if not message:
                continue
            wandb.termerror('Error while calling W&B API: %s' % message)

    def disabled(self):
        """Read the ``disabled`` flag from the default settings section.

        ``False`` is passed as the fallback when the key is not set.
        """
        default_section = Settings.DEFAULT_SECTION
        return self._settings.get(default_section, 'disabled', fallback=False)

    def save_pip(self, out_dir):
        """Save the current working set of pip packages to requirements.txt.

        Writes one sorted ``key==version`` line per distribution in
        ``pkg_resources.working_set`` to ``<out_dir>/requirements.txt``.

        This is best effort: any failure, including a missing
        ``pkg_resources``, is logged with its traceback and otherwise ignored.
        """
        try:
            import pkg_resources

            pinned = sorted(
                "%s==%s" % (dist.key, dist.version)
                for dist in pkg_resources.working_set)
            with open(os.path.join(out_dir, 'requirements.txt'), 'w') as f:
                f.write("\n".join(pinned))
        except Exception:
            # Deliberately non-fatal, but keep the cause visible in the logs.
            logger.error("Error saving pip packages", exc_info=True)

    def save_patches(self, out_dir):
        """Save the current state of this repository to one or more patches.

        Makes one patch against HEAD and another one against the most recent
        commit that occurs in an upstream branch. This way we can be robust
        to history editing as long as the user never does "push -f" to break
        history on an upstream branch.

        Writes the first patch to <out_dir>/diff.patch, but only when the
        working tree is dirty. Writes the second to
        <out_dir>/upstream_diff_<commit_id>.patch, but only when an upstream
        fork point exists that differs from HEAD.

        Args:
            out_dir (str): Directory to write the patch files.

        Returns:
            False if git integration is disabled, otherwise None. ValueError,
            CalledProcessError and TimeoutExpired raised while diffing are
            logged, not propagated.
        """
        if not self.git.enabled:
            return False

        def write_diff(path, ref, root):
            # Run `git diff <ref>` into `path`. Submodule contents are included
            # when the repository reports a submodule diff.
            cmd = ['git', 'diff']
            if self.git.has_submodule_diff:
                cmd.append('--submodule=diff')
            cmd.append(ref)
            with open(path, 'wb') as out:
                subprocess.check_call(cmd, stdout=out, cwd=root, timeout=5)

        try:
            root = self.git.root
            if self.git.dirty:
                # we diff against HEAD to ensure we get changes in the index
                write_diff(os.path.join(out_dir, 'diff.patch'), 'HEAD', root)

            upstream_commit = self.git.get_upstream_fork_point()
            if upstream_commit and upstream_commit != self.git.repo.head.commit:
                sha = upstream_commit.hexsha
                write_diff(
                    os.path.join(out_dir, 'upstream_diff_{}.patch'.format(sha)),
                    sha, root)
        # TODO: A customer saw `ValueError: Reference at 'refs/remotes/origin/foo' does not exist`
        # so we now catch ValueError.  Catching this error feels too generic.
        except (ValueError, subprocess.CalledProcessError,
                subprocess.TimeoutExpired) as e:
            logger.error('Error generating diff: %s', e)

    def set_current_run_id(self, run_id):
        """Record ``run_id``; it is read back through ``current_run_id``."""
        self._current_run_id = run_id

    @property
    def current_run_id(self):
        """The run id last stored by ``set_current_run_id`` (None initially)."""
        run_id = self._current_run_id
        return run_id

    @property
    def user_agent(self):
        """User-Agent header value that identifies this client and its version."""
        return 'W&B Internal Client {}'.format(__version__)

    @property
    def api_key(self):
        """Resolve the API key, or None when none is configured.

        A non-empty key in the environment (``env.API_KEY``) wins. Otherwise
        the last element of the netrc credentials for ``api_url`` is used.
        """
        netrc_auth = requests.utils.get_netrc_auth(self.api_url)
        # Environment should take precedence
        env_key = self._environ.get(env.API_KEY)
        if env_key:
            return env_key
        return netrc_auth[-1] if netrc_auth else None

    @property
    def api_url(self):
        """Base URL of the W&B API: the resolved ``base_url`` setting."""
        base_url = self.settings('base_url')
        return base_url

    @property
    def app_url(self):
        """URL of the W&B web app that corresponds to ``api_url``.

        * ``*.test`` hosts, or a truthy ``dev_prod`` setting, map to
          ``http://app.test``.
        * An API URL ending in ``:11001`` (on-prem VM) maps to the same URL
          ending in ``:11000``.
        * ``https://api.<rest>`` maps to ``https://app.<rest>``.
        * Anything else is returned unchanged.
        """
        api_url = self.api_url
        # Development
        if api_url.endswith('.test') or self.settings().get("dev_prod"):
            return 'http://app.test'
        # On-prem VM: rewrite only the trailing port. str.replace would also
        # rewrite any other ':11001' occurring earlier in the URL.
        if api_url.endswith(':11001'):
            return api_url[:-len(':11001')] + ':11000'
        # Normal: swap only the leading "api." host label. str.replace would
        # also rewrite later matches, e.g. "https://api.myapi.example.com"
        # would become "https://app.myapp.example.com".
        if api_url.startswith('https://api.'):
            return 'https://app.' + api_url[len('https://api.'):]
        # Unexpected
        return api_url

    def settings(self, key=None, section=None):
        """Return the effective settings, or a single setting value.

        Settings are layered in three steps:
        1. ``self.default_settings``.
        2. The items of ``section`` in the settings file.
        3. ``entity``, ``project``, ``base_url`` and ``ignore_globs``, each
           read from the default section and passed through its ``env.get_*``
           resolver together with ``self._environ``.

        Args:
            key (str, optional): If provided, only this setting's value is
                returned (KeyError if absent).
            section (str, optional): Section of the settings file whose items
                are merged in step 2.

        Returns:
            The merged dict when ``key`` is None, for example::

                {
                    "entity": "models",
                    "base_url": "https://api.wandb.ai",
                    "project": None
                }

            otherwise ``merged[key]``.
        """
        merged = dict(self.default_settings)
        merged.update(self._settings.items(section=section))

        resolvers = (
            ('entity', env.get_entity),
            ('project', env.get_project),
            ('base_url', env.get_base_url),
            ('ignore_globs', env.get_ignore),
        )
        # Resolve every override against the pre-override values first, then
        # apply them all at once.
        overrides = {}
        for name, resolve in resolvers:
            stored = self._settings.get(Settings.DEFAULT_SECTION,
                                        name,
                                        fallback=merged.get(name))
            overrides[name] = resolve(stored, env=self._environ)
        merged.update(overrides)

        if key is None:
            return merged
        return merged[key]

    def clear_setting(self, key):
        """Remove ``key`` from the default section of the settings file."""
        default_section = Settings.DEFAULT_SECTION
        self._settings.clear(default_section, key)

    def set_setting(self, key, value, globally=False):
        """Persist a setting in the default section.

        ``entity`` and ``project`` are also mirrored into ``self._environ``
        through ``env.set_entity`` and ``env.set_project``.

        Args:
            key (str): Setting name.
            value: Value to store.
            globally (bool): Forwarded to ``Settings.set``.
        """
        self._settings.set(Settings.DEFAULT_SECTION, key, value,
                           globally=globally)
        if key not in ('entity', 'project'):
            return
        mirror = env.set_entity if key == 'entity' else env.set_project
        mirror(value, env=self._environ)

    def parse_slug(self, slug, project=None, run=None):
        """Split a ``"project/run"`` slug into a ``(project, run)`` tuple.

        If ``slug`` contains a ``/``, its first two segments override the
        ``project`` and ``run`` arguments. Otherwise:

        * ``project`` falls back to the configured ``project`` setting.
        * ``run`` falls back to ``slug``, then to ``env.get_run``, and finally
          to ``"latest"`` if it is still None.

        Raises:
            CommError: If no project can be determined.
        """
        if slug and "/" in slug:
            segments = slug.split("/")
            return (segments[0], segments[1])

        project = project or self.settings().get("project")
        if project is None:
            raise CommError("No default project configured.")
        run = run or slug or env.get_run(env=self._environ)
        return (project, "latest" if run is None else run)

    @normalize_exceptions
    def viewer(self):
        """Fetch the ``viewer`` record: id, entity and team names.

        Returns:
            dict: The ``viewer`` object from the response, or ``{}`` when it is
            missing or empty.
        """
        viewer_query = gql('''
        query Viewer{
            viewer {
                id
                entity
                teams {
                    edges {
                        node {
                            name
                        }
                    }
                }
            }
        }
        ''')
        return self.gql(viewer_query).get('viewer') or {}

    @normalize_exceptions
    def list_projects(self, entity=None):
        """List the first 10 projects of an entity.

        Args:
            entity (str, optional): Entity whose projects are listed. Defaults
                to the configured ``entity`` setting.

        Returns:
            The ``models`` connection passed through ``_flatten_edges``. Each
            node carries ``id``, ``name`` and ``description``.
        """
        projects_query = gql('''
        query Models($entity: String!) {
            models(first: 10, entityName: $entity) {
                edges {
                    node {
                        id
                        name
                        description
                    }
                }
            }
        }
        ''')
        owner = entity or self.settings('entity')
        response = self.gql(projects_query, variable_values={'entity': owner})
        return self._flatten_edges(response['models'])

    @normalize_exceptions
    def project(self, project, entity=None):
        """Retrieve a single project.

        Args:
            project (str): The project to get details for.
            entity (str, optional): The entity to scope this project to. It is
                sent as-is; unlike most methods here, there is no fallback to
                the configured entity.

        Returns:
            The ``model`` object with ``id``, ``name``, ``repo``,
            ``dockerImage`` and ``description``, as returned by the API.
        """
        project_query = gql('''
        query Models($entity: String, $project: String!) {
            model(name: $project, entityName: $entity) {
                id
                name
                repo
                dockerImage
                description
            }
        }
        ''')
        variables = {'entity': entity, 'project': project}
        response = self.gql(project_query, variable_values=variables)
        return response['model']

    @normalize_exceptions
    def sweep(self, sweep, specs, project=None, entity=None):
        """Retrieve a sweep together with its runs.

        Args:
            sweep (str): The sweep to get details for.
            specs (str): History specs, forwarded to ``sampledHistory``.
            project (str, optional): The project to scope this sweep to.
                Defaults to the configured ``project`` setting.
            entity (str, optional): The entity to scope this sweep to.
                Defaults to the configured ``entity`` setting.

        Returns:
            The ``sweep`` object, with ``runs`` passed through
            ``_flatten_edges``. A falsy ``sweep`` value is returned as-is.
        """
        sweep_query = gql('''
        query Models($entity: String, $project: String!, $sweep: String!, $specs: [JSONString!]!) {
            model(name: $project, entityName: $entity) {
                sweep(sweepName: $sweep) {
                    id
                    name
                    method
                    state
                    description
                    config
                    createdAt
                    heartbeatAt
                    updatedAt
                    earlyStopJobRunning
                    bestLoss
                    controller
                    scheduler
                    runs {
                        edges {
                            node {
                                name
                                state
                                config
                                exitcode
                                heartbeatAt
                                shouldStop
                                failed
                                stopped
                                running
                                summaryMetrics
                                sampledHistory(specs: $specs)
                            }
                        }
                    }
                }
            }
        }
        ''')
        variables = {
            'entity': entity or self.settings('entity'),
            'project': project or self.settings('project'),
            'sweep': sweep,
            'specs': specs
        }
        sweep_data = self.gql(sweep_query,
                              variable_values=variables)['model']['sweep']
        if not sweep_data:
            return sweep_data
        sweep_data['runs'] = self._flatten_edges(sweep_data['runs'])
        return sweep_data

    @normalize_exceptions
    def list_runs(self, project, entity=None):
        """List the first 10 runs of a project.

        Args:
            project (str): The project to scope the runs to. Falls back to
                the configured ``project`` setting when falsy.
            entity (str, optional): The entity to scope this project to.
                Falls back to the configured ``entity`` setting.

        Returns:
            The project's ``buckets`` passed through ``_flatten_edges``. Each
            node carries ``id``, ``name``, ``displayName`` and ``description``.
        """
        runs_query = gql('''
        query Buckets($model: String!, $entity: String!) {
            model(name: $model, entityName: $entity) {
                buckets(first: 10) {
                    edges {
                        node {
                            id
                            name
                            displayName
                            description
                        }
                    }
                }
            }
        }
        ''')
        variables = {
            'entity': entity or self.settings('entity'),
            'model': project or self.settings('project')
        }
        response = self.gql(runs_query, variable_values=variables)
        return self._flatten_edges(response['model']['buckets'])

    @normalize_exceptions
    def launch_run(self, command, project=None, entity=None, run_id=None):
        """Launch a run in the cloud.

        When the working tree is dirty, the current ``git diff`` is sent as
        the patch. Otherwise an empty string is sent.

        Args:
            command (str): The command to run.
            project (str, optional): The project to scope the run to. Defaults
                to the configured ``project`` setting.
            entity (str, optional): The entity to scope this project to.
                Defaults to the configured ``entity`` setting.
            run_id (str, optional): The run_id to scope to.

        Returns:
            dict: The raw GraphQL response. Its ``launchRun`` entry holds
            ``podName``, ``status`` and ``runId``.
        """
        query = gql('''
        mutation launchRun(
            $entity: String
            $model: String
            $runId: String
            $image: String
            $command: String
            $patch: String
            $cwd: String
            $datasets: [String]
        ) {
            launchRun(input: {id: $runId, entityName: $entity, patch: $patch, modelName: $model,
                image: $image, command: $command, datasets: $datasets, cwd: $cwd}) {
                podName
                status
                runId
            }
        }
        ''')
        patch = BytesIO()
        if self.git.dirty:
            self.git.repo.git.execute(['git', 'diff'], output_stream=patch)
            patch.seek(0)
        cwd = "."
        if self.git.enabled:
            # Express the current directory relative to the repo root, e.g.
            # "./sub/dir". Strip only the leading root path: str.replace would
            # also delete later repeats of the same text deeper in the path.
            here = os.getcwd()
            repo_root = self.git.repo.working_dir
            if here.startswith(repo_root):
                here = here[len(repo_root):]
            cwd = cwd + here
        return self.gql(query,
                        variable_values={
                            'entity': entity or self.settings('entity'),
                            'model': project or self.settings('project'),
                            'command': command,
                            'runId': run_id,
                            'patch': patch.read().decode("utf8"),
                            'cwd': cwd
                        })

    @normalize_exceptions
    def run_config(self, project, run=None, entity=None):
        """Get the relevant configs for a run.

        Args:
            project (str): The project to download, (can include bucket)
            run (str, optional): The run to download
            entity (str, optional): The entity to scope this project to.

        Returns:
            tuple: ``(commit, config, patch, metadata)``.
                ``config`` is the run config decoded from JSON (``{}`` when
                empty). ``metadata`` is the decoded ``wandb-metadata.json``
                (``{}`` when the run has no such file).

        Raises:
            ValueError: If the project or run is not found. This may be
                rewrapped by ``@normalize_exceptions``.
        """
        query = gql('''
        query Model($name: String!, $entity: String!, $run: String!) {
            model(name: $name, entityName: $entity) {
                bucket(name: $run) {
                    config
                    commit
                    patch
                    files(names: ["wandb-metadata.json"]) {
                        edges {
                            node {
                                url
                            }
                        }
                    }
                }
            }
        }
        ''')

        response = self.gql(query,
                            variable_values={
                                'name': project,
                                'run': run,
                                'entity': entity
                            })
        model = response['model']
        # The bucket may also come back null (e.g. an unknown run name). Treat
        # it like a missing project instead of crashing with TypeError below.
        bucket = model['bucket'] if model is not None else None
        if bucket is None:
            raise ValueError("Run {}/{}/{} not found".format(
                entity, project, run))
        commit = bucket['commit']
        patch = bucket['patch']
        config = json.loads(bucket['config'] or '{}')
        file_edges = bucket['files']['edges']
        if file_edges:
            url = file_edges[0]['node']['url']
            # Without a timeout, requests may wait indefinitely on a stalled
            # server.
            res = requests.get(url, timeout=self.HTTP_TIMEOUT)
            res.raise_for_status()
            metadata = res.json()
        else:
            metadata = {}
        return (commit, config, patch, metadata)

    @normalize_exceptions
    def run_resume_status(self, entity, project_name, name):
        """Check if a run exists and get resume information.

        When the project is found, the ``project`` setting (and the
        ``entity`` setting, if the response names one) are updated as a side
        effect.

        Args:
            entity (str, optional): The entity to scope this project to.
            project_name (str): The project containing the run.
            name (str): The name of the run to look up.

        Returns:
            The run's ``bucket`` record, with line counts and history/events
            tails. Returns None when the project or run does not exist.
        """
        query = gql('''
        query Model($project: String!, $entity: String, $name: String!) {
            model(name: $project, entityName: $entity) {
                id
                name
                entity {
                    id
                    name
                }

                bucket(name: $name, missingOk: true) {
                    id
                    name
                    summaryMetrics
                    displayName
                    logLineCount
                    historyLineCount
                    eventsLineCount
                    historyTail
                    eventsTail
                }
            }
        }
        ''')

        response = self.gql(query,
                            variable_values={
                                'entity': entity,
                                'project': project_name,
                                'name': name,
                            })

        # "model" can be present but null when the project does not exist.
        # Testing `'bucket' in None` would raise TypeError.
        project = response.get('model')
        if not project or 'bucket' not in project:
            return None

        self.set_setting('project', project_name)
        # "entity" may likewise be present but null.
        project_entity = project.get('entity')
        if project_entity:
            self.set_setting('entity', project_entity['name'])

        return project['bucket']

    @normalize_exceptions
    def check_stop_requested(self, project_name, entity_name, run_id):
        """Report whether a stop has been requested for a run.

        Args:
            project_name (str): Project containing the run.
            entity_name (str): Entity owning the project.
            run_id (str): Name of the run to check.

        Returns:
            The run's ``stopped`` value. Returns False when the project or run
            is missing or empty in the response.
        """
        stop_query = gql('''
        query Model($projectName: String, $entityName: String, $runId: String!) {
            project(name:$projectName, entityName:$entityName) {
                run(name:$runId) {
                    stopped
                }
            }
        }
        ''')

        variables = {
            'projectName': project_name,
            'entityName': entity_name,
            'runId': run_id,
        }
        response = self.gql(stop_query, variable_values=variables)

        run_info = (response.get('project') or {}).get('run')
        return run_info['stopped'] if run_info else False

    def format_project(self, project):
        """Normalize a project name into a slug.

        The name is lower-cased, each run of non-word characters becomes a
        single ``-``, and leading/trailing ``-`` and ``_`` are stripped.
        """
        dashed = re.sub(r'\W+', '-', project.lower())
        return dashed.strip("-_")

    @normalize_exceptions
    def upsert_project(self, project, id=None, description=None, entity=None):
        """Create a new project or update an existing one.

        Args:
            project (str): The project to create. It is normalized with
                ``format_project`` before sending.
            id (str, optional): Id of an existing project to update.
            description (str, optional): A description of this project.
            entity (str, optional): The entity to scope this project to.
                Defaults to the configured ``entity`` setting.

        Returns:
            The ``model`` object (``name``, ``description``) from the
            ``upsertModel`` response. The git remote URL is sent as ``repo``.
        """
        upsert_mutation = gql('''
        mutation UpsertModel($name: String!, $id: String, $entity: String!, $description: String, $repo: String)  {
            upsertModel(input: { id: $id, name: $name, entityName: $entity, description: $description, repo: $repo }) {
                model {
                    name
                    description
                }
            }
        }
        ''')
        variables = {
            'name': self.format_project(project),
            'entity': entity or self.settings('entity'),
            'description': description,
            'repo': self.git.remote_url,
            'id': id
        }
        result = self.gql(upsert_mutation, variable_values=variables)
        return result['upsertModel']['model']

    @normalize_exceptions
    def upsert_run(self,
                   id=None,
                   name=None,
                   project=None,
                   host=None,
                   group=None,
                   tags=None,
                   config=None,
                   description=None,
                   entity=None,
                   state=None,
                   display_name=None,
                   notes=None,
                   repo=None,
                   job_type=None,
                   program_path=None,
                   commit=None,
                   sweep_name=None,
                   summary_metrics=None,
                   num_retries=None):
        """Create or update a run (a "bucket" on the server).

        Args:
            id (str, optional): The existing run to update.
            name (str, optional): The name of the run to create.
            project (str, optional): The name of the project.
            host (str, optional): Host name recorded on the run.
            group (str, optional): Name of the group this run is a part of.
            tags (list, optional): Tags to attach to the run.
            config (dict, optional): The latest config params. Serialized to
                JSON before sending.
            description (str, optional): A description of this run. Empty or
                whitespace-only values are sent as None.
            entity (str, optional): The entity to scope this run to. Defaults
                to the configured ``entity`` setting.
            state (str, optional): State of the program.
            display_name (str, optional): Display name of the run.
            notes (str, optional): Notes attached to the run.
            repo (str, optional): Url of the program's repository.
            job_type (str, optional): Type of job, e.g 'train'.
            program_path (str, optional): Path to the program.
            commit (str, optional): The Git SHA to associate the run with.
            sweep_name (str, optional): Sweep this run belongs to.
            summary_metrics (str, optional): The JSON summary metrics.
            num_retries (int, optional): When given, forwarded to ``self.gql``
                as ``num_retries``.

        Returns:
            The ``bucket`` object from the ``upsertBucket`` response. If it
            names a project (and that project's entity), those names are
            stored via ``set_setting``.
        """
        mutation = gql('''
        mutation UpsertBucket(
            $id: String, $name: String,
            $project: String,
            $entity: String!,
            $groupName: String,
            $description: String,
            $displayName: String,
            $notes: String,
            $commit: String,
            $config: JSONString,
            $host: String,
            $debug: Boolean,
            $program: String,
            $repo: String,
            $jobType: String,
            $state: String,
            $sweep: String,
            $tags: [String!],
            $summaryMetrics: JSONString,
        ) {
            upsertBucket(input: {
                id: $id,
                name: $name,
                groupName: $groupName,
                modelName: $project,
                entityName: $entity,
                description: $description,
                displayName: $displayName,
                notes: $notes,
                config: $config,
                commit: $commit,
                host: $host,
                debug: $debug,
                jobProgram: $program,
                jobRepo: $repo,
                jobType: $jobType,
                state: $state,
                sweep: $sweep,
                tags: $tags,
                summaryMetrics: $summaryMetrics,
            }) {
                bucket {
                    id
                    name
                    displayName
                    description
                    config
                    project {
                        id
                        name
                        entity {
                            id
                            name
                        }
                    }
                }
            }
        }
        ''')
        serialized_config = json.dumps(config) if config is not None else None
        has_text = bool(description) and not description.isspace()
        clean_description = description if has_text else None

        retry_kwargs = {} if num_retries is None else {'num_retries': num_retries}

        variable_values = {
            'id': id,
            'entity': entity or self.settings('entity'),
            'name': name,
            'project': project,
            'groupName': group,
            'tags': tags,
            'description': clean_description,
            'config': serialized_config,
            'commit': commit,
            'displayName': display_name,
            'notes': notes,
            'host': host,
            'debug': env.is_debug(env=self._environ),
            'repo': repo,
            'program': program_path,
            'jobType': job_type,
            'state': state,
            'sweep': sweep_name,
            'summaryMetrics': summary_metrics
        }

        response = self.gql(mutation,
                            variable_values=variable_values,
                            **retry_kwargs)

        bucket = response['upsertBucket']['bucket']
        bucket_project = bucket.get('project')
        if bucket_project:
            self.set_setting('project', bucket_project['name'])
            bucket_entity = bucket_project.get('entity')
            if bucket_entity:
                self.set_setting('entity', bucket_entity['name'])

        return bucket

    @normalize_exceptions
    def upload_urls(self,
                    project,
                    files,
                    run=None,
                    entity=None,
                    description=None):
        """Generate temporary resumable upload urls.

        Args:
            project (str): The project to upload to
            files (list or dict): The filenames to upload (for a dict, its
                keys are the filenames)
            run (str, optional): The run to upload to. Defaults to the `run`
                setting.
            entity (str, optional): The entity to scope this project to.
                Defaults to the `entity` setting.
            description (str, optional): Sent to the server as the bucket's
                `desc` argument.

        Returns:
            (bucket_id, file_info)
            bucket_id: id of bucket we uploaded to
            file_info: A dict of filenames and urls, also indicates if this revision already has uploaded files.
                {
                    'weights.h5': { "url": "https://weights.url" },
                    'model.json': { "url": "https://model.json", "updatedAt": '2013-04-26T22:22:23.832Z', 'md5': 'mZFLkyvTelC5g8XnyQrpOw==' },
                }

        Raises:
            CommError: if the server returns no model or no bucket for the
                requested entity/project/run.
        """
        query = gql('''
        query Model($name: String!, $files: [String]!, $entity: String!, $run: String!, $description: String) {
            model(name: $name, entityName: $entity) {
                bucket(name: $run, desc: $description) {
                    id
                    files(names: $files) {
                        edges {
                            node {
                                name
                                url(upload: true)
                                updatedAt
                            }
                        }
                    }
                }
            }
        }
        ''')
        run_id = run or self.settings('run')
        entity = entity or self.settings('entity')
        query_result = self.gql(query,
                                variable_values={
                                    'name': project,
                                    'run': run_id,
                                    'entity': entity,
                                    'description': description,
                                    # Iterating a dict yields its keys, so
                                    # both accepted `files` shapes work.
                                    'files': list(files)
                                })

        # `model` itself can be null (presumably for an unknown project);
        # report that like a missing run instead of crashing with TypeError.
        model = query_result['model']
        bucket = model['bucket'] if model else None
        if not bucket:
            raise CommError("Run does not exist {}/{}/{}.".format(
                entity, project, run_id))

        file_info = {
            node['name']: node
            for node in self._flatten_edges(bucket['files'])
        }
        return bucket['id'], file_info

    @normalize_exceptions
    def download_urls(self, project, run=None, entity=None):
        """Generate download urls for every file of a run.

        Args:
            project (str): The project to download
            run (str, optional): The run to download from. Defaults to the
                `run` setting.
            entity (str, optional): The entity to scope this project to.
                Defaults to the `entity` setting.

        Returns:
            A dict mapping file names to their file metadata

                {
                    'weights.h5': { "url": "https://weights.url", "updatedAt": '2013-04-26T22:22:23.832Z', 'md5': 'mZFLkyvTelC5g8XnyQrpOw==' },
                    'model.json': { "url": "https://model.url", "updatedAt": '2013-04-26T22:22:23.832Z', 'md5': 'mZFLkyvTelC5g8XnyQrpOw==' }
                }

        Raises:
            CommError: if the server returns no model or no bucket for the
                requested entity/project/run (same message as upload_urls).
        """
        query = gql('''
        query Model($name: String!, $entity: String!, $run: String!)  {
            model(name: $name, entityName: $entity) {
                bucket(name: $run) {
                    files {
                        edges {
                            node {
                                name
                                url
                                md5
                                updatedAt
                            }
                        }
                    }
                }
            }
        }
        ''')
        run_id = run or self.settings('run')
        entity = entity or self.settings('entity')
        query_result = self.gql(query,
                                variable_values={
                                    'name': project,
                                    'run': run_id,
                                    'entity': entity
                                })

        # Either level may be null; surface a clear error rather than a
        # TypeError from subscripting None.
        model = query_result['model']
        bucket = model['bucket'] if model else None
        if not bucket:
            raise CommError("Run does not exist {}/{}/{}.".format(
                entity, project, run_id))

        files = self._flatten_edges(bucket['files'])
        return {file['name']: file for file in files if file}

    @normalize_exceptions
    def download_url(self, project, file_name, run=None, entity=None):
        """Look up the download url and metadata of a single file.

        Args:
            project (str): The project to download
            file_name (str): The name of the file to download
            run (str, optional): The run to download from. Defaults to the
                `run` setting.
            entity (str, optional): The entity to scope this project to.
                Defaults to the `entity` setting.

        Returns:
            The file's metadata dict, or None if the model, the run or the
            file is missing, or the file has no `updatedAt`.

                { "url": "https://weights.url", "updatedAt": '2013-04-26T22:22:23.832Z', 'md5': 'mZFLkyvTelC5g8XnyQrpOw==' }

        """
        query = gql('''
        query Model($name: String!, $fileName: String!, $entity: String!, $run: String!)  {
            model(name: $name, entityName: $entity) {
                bucket(name: $run) {
                    files(names: [$fileName]) {
                        edges {
                            node {
                                name
                                url
                                md5
                                updatedAt
                            }
                        }
                    }
                }
            }
        }
        ''')
        query_result = self.gql(query,
                                variable_values={
                                    'name': project,
                                    'run': run or self.settings('run'),
                                    'fileName': file_name,
                                    'entity': entity or self.settings('entity')
                                })
        # Guard both levels: a missing run (null bucket) used to raise
        # TypeError even though a missing model already returned None.
        model = query_result['model']
        bucket = model['bucket'] if model else None
        if not bucket:
            return None

        files = self._flatten_edges(bucket['files'])
        # A node without `updatedAt` is treated as not available
        # (presumably never uploaded).
        if files and files[0].get('updatedAt'):
            return files[0]
        return None

    @normalize_exceptions
    def download_file(self, url):
        """Open a streaming GET request for *url*.

        Args:
            url (str): The url to download

        Returns:
            tuple: ``(content_length, response)``. ``content_length`` is the
            integer Content-Length header (0 when the header is absent).
            ``response`` is the still-streaming ``requests`` response, which
            the caller consumes.

        Raises:
            requests.HTTPError: on an error status, via ``raise_for_status``
                (normalize_exceptions may translate it).
        """
        streamed = requests.get(url, stream=True)
        streamed.raise_for_status()
        declared_length = streamed.headers.get('content-length', 0)
        return int(declared_length), streamed

    @normalize_exceptions
    def download_write_file(self, metadata, out_dir=None):
        """Download a file from a run and write it to a local directory.

        Args:
            metadata (obj): The metadata object for the file to download. Comes from Api.download_urls().
                Its `name`, `md5` and `url` keys are used.
            out_dir (str, optional): Destination directory. Defaults to
                ``wandb_dir()``.

        Returns:
            A tuple of the file's local path and the streaming response. The streaming response is None if the file already existed and was up to date.
        """
        file_name = metadata['name']
        path = os.path.join(out_dir or wandb_dir(), file_name)
        # Check the destination file itself. Checking the bare name looked
        # relative to the CWD, not out_dir/wandb_dir(), so up-to-date files
        # were re-downloaded or stale ones wrongly skipped.
        if self.file_current(path, metadata['md5']):
            return path, None

        _, response = self.download_file(metadata['url'])

        with open(path, "wb") as file:
            for data in response.iter_content(chunk_size=1024):
                file.write(data)

        return path, response

    def upload_file(self, url, file, callback=None, extra_headers=None):
        """Upload a file to W&B with a single PUT, flagging retryable failures.

        Args:
            url (str): The url to upload to
            file (file object): The open file to upload. Its ``name`` is
                used in the empty-file error message.
            callback (:obj:`func`, optional): Passed to ``Progress``. Per the
                original docs it receives the number of bytes uploaded since
                the last time it was called, used to report progress.
            extra_headers (dict, optional): Extra HTTP headers for the PUT.
                The caller's mapping is copied, never mutated.

        Returns:
            The requests library response object

        Raises:
            CommError: if the file is empty.
            retry.TransientException: handed to ``util.sentry_reraise`` for
                HTTP 308/409/429/500/502/503/504, timeouts and connection
                errors, presumably so ``upload_file_retry`` retries them.
            requests.exceptions.RequestException: handed to
                ``util.sentry_reraise`` for any other request failure.
        """
        # None instead of a shared mutable default; `.copy()` keeps the
        # caller's mapping type (e.g. a case-insensitive dict).
        headers = extra_headers.copy() if extra_headers is not None else {}
        response = None
        progress = Progress(file, callback=callback)
        if progress.len == 0:
            raise CommError("%s is an empty file" % file.name)
        try:
            response = requests.put(url, data=progress, headers=headers)
            response.raise_for_status()
        except requests.exceptions.RequestException as e:
            status_code = e.response.status_code if e.response is not None else 0
            # Retry errors from cloud storage or local network issues
            if status_code in (308, 409, 429, 500,
                               502, 503, 504) or isinstance(
                                   e, (requests.exceptions.Timeout,
                                       requests.exceptions.ConnectionError)):
                util.sentry_reraise(retry.TransientException(exc=e))
            else:
                util.sentry_reraise(e)

        return response

    # Retrying variant of upload_file, built at class-creation time by wrapping
    # the plain function with retry.retriable(num_retries=5) and then
    # normalize_exceptions. It is called like a normal method, e.g.
    # self.upload_file_retry(url, open_file, callback) in push().
    # NOTE(review): which exceptions trigger a retry is decided by
    # retry.retriable (presumably retry.TransientException, which upload_file
    # raises for transient failures) — confirm in the retry module.
    upload_file_retry = normalize_exceptions(
        retry.retriable(num_retries=5)(upload_file))

    @normalize_exceptions
    def register_agent(self, host, sweep_id=None, project_name=None):
        """Register a new sweep agent with the server.

        Args:
            host (str): hostname the agent runs on
            sweep_id (str, optional): id of the sweep the agent works on
            project_name (str, optional): project containing the sweep.
                Defaults to the `project` setting.

        Returns:
            dict: the created agent (its `id` field).

        Raises:
            UsageError: when the request fails with HTTP 400, carrying the
                first error message from the response body (raised from the
                check_retry_fn passed to self.gql).
        """
        mutation = gql('''
        mutation CreateAgent(
            $host: String!
            $projectName: String!,
            $entityName: String!,
            $sweep: String!
        ) {
            createAgent(input: {
                host: $host,
                projectName: $projectName,
                entityName: $entityName,
                sweep: $sweep,
            }) {
                agent {
                    id
                }
            }
        }
        ''')
        project_name = (self.settings('project')
                        if project_name is None else project_name)

        def retry_unless_validation_error(err):
            # Anything but an HTTP 400 keeps the normal retry policy; a 400 is
            # a validation failure, so surface the server's message instead.
            if (not isinstance(err, requests.HTTPError)
                    or err.response.status_code != 400):
                return True
            payload = json.loads(err.response.content)
            raise UsageError(payload['errors'][0]['message'])

        agent_vars = {
            'host': host,
            'entityName': self.settings("entity"),
            'projectName': project_name,
            'sweep': sweep_id
        }
        result = self.gql(mutation,
                          variable_values=agent_vars,
                          check_retry_fn=retry_unless_validation_error)
        return result['createAgent']['agent']

    def agent_heartbeat(self, agent_id, metrics, run_states):
        """Notify server about agent state, receive commands.

        Best-effort: any failure is logged and an empty command list is
        returned instead of raising.

        Args:
            agent_id (str): agent_id
            metrics (dict): system metrics, sent JSON-encoded
            run_states (dict): run_id: state mapping, sent JSON-encoded
        Returns:
            List of commands to execute (decoded from the JSON `commands`
            field), or [] if the request failed.
        """
        mutation = gql('''
        mutation Heartbeat(
            $id: ID!,
            $metrics: JSONString,
            $runState: JSONString
        ) {
            agentHeartbeat(input: {
                id: $id,
                metrics: $metrics,
                runState: $runState
            }) {
                agent {
                    id
                }
                commands
            }
        }
        ''')
        try:
            response = self.gql(mutation,
                                variable_values={
                                    'id': agent_id,
                                    'metrics': json.dumps(metrics),
                                    'runState': json.dumps(run_states)
                                })
        except Exception as e:
            # GQL raises exceptions with stringified python dictionaries :/
            # Other errors (network failures, exceptions without args) do
            # not, so fall back to str(e) rather than letting the parse error
            # escape this best-effort handler.
            try:
                message = ast.literal_eval(e.args[0])["message"]
            except (IndexError, KeyError, TypeError, ValueError, SyntaxError):
                message = str(e)
            logger.error('Error communicating with W&B: %s', message)
            return []
        else:
            return json.loads(response['agentHeartbeat']['commands'])

    @normalize_exceptions
    def upsert_sweep(self,
                     config,
                     controller=None,
                     scheduler=None,
                     obj_id=None):
        """Create or update a sweep and return its name.

        Args:
            config (dict): sweep config. It is serialized with ``yaml.dump``,
                and its optional ``description`` key is also sent separately.
            controller (optional): passed through unchanged as the
                ``controller`` variable (declared JSONString in the mutation)
            scheduler (optional): passed through unchanged as the
                ``scheduler`` variable (declared JSONString in the mutation)
            obj_id (optional): sent as the sweep ``id``. Presumably None
                creates a new sweep; confirm with the server schema.

        Returns:
            str: the sweep's ``name`` as returned by the server.

        Raises:
            UsageError: when the request fails with HTTP 400 or 404, carrying
                the first error message from the response body.
        """
        mutation = gql('''
        mutation UpsertSweep(
            $id: ID,
            $config: String,
            $description: String,
            $entityName: String!,
            $projectName: String!,
            $controller: JSONString,
            $scheduler: JSONString
        ) {
            upsertSweep(input: {
                id: $id,
                config: $config,
                description: $description,
                entityName: $entityName,
                projectName: $projectName,
                controller: $controller,
                scheduler: $scheduler
            }) {
                sweep {
                    name
                }
            }
        }
        ''')

        sweep_vars = {
            'id': obj_id,
            'config': yaml.dump(config),
            'description': config.get("description"),
            'entityName': self.settings("entity"),
            'projectName': self.settings("project"),
            'controller': controller,
            'scheduler': scheduler
        }

        # Validation errors (400/404) must not be retried.
        # TODO(jhr): generalize error handling routines
        def surface_client_errors(err):
            if (not isinstance(err, requests.HTTPError)
                    or err.response.status_code not in (400, 404)):
                return True
            error_body = json.loads(err.response.content)
            raise UsageError(error_body['errors'][0]['message'])

        outcome = self.gql(mutation,
                           variable_values=sweep_vars,
                           check_retry_fn=surface_client_errors)
        return outcome['upsertSweep']['sweep']['name']

    @normalize_exceptions
    def create_anonymous_api_key(self):
        """Create an API key owned by a freshly created anonymous user.

        Returns:
            str: the new API key's ``name`` field (presumably the key itself).
        """
        mutation = gql('''
        mutation CreateAnonymousApiKey {
            createAnonymousEntity(input: {}) {
                apiKey {
                    name
                }
            }
        }
        ''')

        payload = self.gql(mutation, variable_values={})
        api_key = payload['createAnonymousEntity']['apiKey']
        return api_key['name']

    def file_current(self, fname, md5):
        """Return True if *fname* is an existing file whose checksum equals *md5*.

        *md5* must be in the same encoding that ``util.md5_file`` returns
        (the download metadata examples look base64-encoded).
        """
        if not os.path.isfile(fname):
            return False
        return util.md5_file(fname) == md5

    @normalize_exceptions
    def pull(self, project, run=None, entity=None):
        """Download every file of a run from W&B.

        Args:
            project (str): The project to download. It is run through
                ``parse_slug`` together with *run*.
            run (str, optional): The run to download from
            entity (str, optional): The entity to scope this project to

        Returns:
            list: the streaming responses of the files that were actually
            downloaded. Files that were already up to date are skipped.
        """
        project, run = self.parse_slug(project, run=run)
        file_urls = self.download_urls(project, run, entity)
        downloaded = []
        for metadata in file_urls.values():
            _, resp = self.download_write_file(metadata)
            if resp:
                downloaded.append(resp)
        return downloaded

    def get_project(self):
        """Return the configured project, i.e. the `project` setting."""
        configured_project = self.settings('project')
        return configured_project

    @normalize_exceptions
    def push(self,
             files,
             run=None,
             entity=None,
             project=None,
             description=None,
             force=True,
             progress=False):
        """Uploads multiple files to W&B

        Args:
            files (list or dict): The filenames to upload, or a dict mapping
                filenames to already-open file objects
            run (str, optional): The run to upload to. Defaults to
                ``self.current_run_id``.
            entity (str, optional): The entity to scope this project to
            project (str, optional): The name of the project to upload to. Defaults to the one in settings.
            description (str, optional): The description of the changes
            force (bool, optional): Whether to prevent push if git has uncommitted changes.
                NOTE(review): not used by this method.
            progress (callable, or stream): If callable, it is passed to the
                upload as its progress callback (chunk_bytes, total_bytes).
                Otherwise, if truthy, it is the output stream of a click
                progress bar.

        Returns:
            list: one requests response object per uploaded file. Files that
            cannot be opened are reported and skipped.

        Raises:
            CommError: if no project is given or configured.
        """
        if project is None:
            project = self.get_project()
        if project is None:
            raise CommError("No project configured.")
        if run is None:
            run = self.current_run_id

        # TODO(adrian): we use a retriable version of self.upload_file() so
        # will never retry self.upload_urls() here. Instead, maybe we should
        # make push itself retriable.
        _, result = self.upload_urls(project, files, run, entity,
                                     description)
        responses = []
        for file_name, file_info in result.items():
            file_url = file_info['url']

            # If the upload URL is relative, fill it in with the base URL,
            # since its a proxied file store like the on-prem VM.
            if file_url.startswith('/'):
                file_url = '{}{}'.format(self.api_url, file_url)

            try:
                # To handle Windows paths
                # TODO: this doesn't handle absolute paths...
                normal_name = os.path.join(*file_name.split("/"))
                open_file = files[file_name] if isinstance(
                    files, dict) else open(normal_name, "rb")
            except IOError:
                print("%s does not exist" % file_name)
                continue

            try:
                # Every branch uploads to the resolved `file_url`. The
                # no-progress branch used the raw url, which broke relative
                # (proxied) upload urls.
                if not progress:
                    response = self.upload_file_retry(file_url, open_file)
                elif callable(progress):
                    response = self.upload_file_retry(
                        file_url, open_file, progress)
                else:
                    length = os.fstat(open_file.fileno()).st_size
                    with click.progressbar(
                            file=progress,
                            length=length,
                            label='Uploading file: %s' % (file_name),
                            fill_char=click.style('&', fg='green')) as bar:
                        response = self.upload_file_retry(
                            file_url, open_file,
                            lambda bites, _: bar.update(bites))
                responses.append(response)
            finally:
                # Close even when the upload raises, so handles don't leak.
                open_file.close()
        return responses

    def get_file_stream_api(self):
        """Return the cached FileStreamApi for the current run, creating it lazily.

        The object is not started here; call ``start()`` on it to launch the
        thread that talks to W&B.

        Raises:
            UsageError: if it must be created but no current run is set.
        """
        if self._file_stream_api:
            return self._file_stream_api
        if self._current_run_id is None:
            raise UsageError(
                'Must have a current run to use file stream API.')
        self._file_stream_api = FileStreamApi(self, self._current_run_id)
        return self._file_stream_api

    def _status_request(self, url, length):
        """Ask a resumable-upload endpoint (Google's, per the original note) how much was received.

        Sends an empty PUT carrying ``Content-Range: bytes */<length>`` and
        returns the raw requests response.
        """
        status_headers = {
            'Content-Length': '0',
            'Content-Range': 'bytes */%i' % length
        }
        return requests.put(url=url, headers=status_headers)

    def _flatten_edges(self, response):
        """Return an array from the nested graphql relay structure"""
        return [node['node'] for node in response['edges']]
Пример #26
0
class APIClient(object):
    """Client for a bearer-token-authenticated GraphQL API.

    Queries go through a ``gql`` ``Client`` over ``RequestsHTTPTransport``.
    ``publish_results`` instead posts multipart form data with ``requests``
    so a file can be attached.
    """

    def __new__(cls, baseUrl, token):
        # Validate before __init__ runs: both the endpoint and the token are
        # required for a usable client.
        if baseUrl is None or token is None:
            raise ValueError("APIClient requires both baseUrl and token")
        return super(APIClient, cls).__new__(cls)

    def __init__(self, baseUrl, token):
        """Build the authenticated transport and GraphQL client.

        Args:
            baseUrl (str): GraphQL endpoint URL.
            token (str): Bearer token sent in the Authorization header.
        """
        self.baseUrl = baseUrl
        self.headers = {
            'Authorization': 'Bearer %s' % token,
        }
        transport = RequestsHTTPTransport(baseUrl,
                                          headers=self.headers,
                                          use_json=True)
        self.client = Client(transport=transport,
                             fetch_schema_from_transport=True)
        # The token is a credential, so it is deliberately not logged.
        logging.debug(
            "Api Client initialized with baseUrl: {}".format(baseUrl))

    @staticmethod
    def _gql_string(value):
        """Return ``str(value)`` as a quoted, escaped GraphQL string literal.

        GraphQL string escapes are the same as JSON's, so ``json.dumps``
        yields a valid literal. Quotes, backslashes or newlines in a value
        can then neither break nor inject into the query text.
        """
        return json.dumps(str(value))

    def get_assets_for_recording(self, recording_id, asset_type):
        """Query to retrieve filtered assets for a recording by a specific asset type.

        Returns:
            The list of asset records (id, contentType, createdDateTime,
            signedUri), or None if the request fails (the error is logged).
        """
        logging.debug("Getting {} assets for a recording {} ".format(
            asset_type, recording_id))
        query = gql('''
                    query{
                      temporalDataObject(id:%s){
                        assets(type:%s) {
                          records  {
                            id
                            contentType
                            createdDateTime
                            signedUri
                          }
                        }
                      }
                    }
                    ''' % (self._gql_string(recording_id),
                           self._gql_string(asset_type)))
        try:
            response = self.client.execute(query)
            return response['temporalDataObject']['assets']['records']

        except Exception as e:
            logging.error(
                'Failed to find {} for recording_id {} due to: {}'.format(
                    asset_type, recording_id, e))
        return None

    def publish_results(self, recording_id, results):
        """Performs a Multipart/form-data request with the graphql query and the output file

        Args:
            recording_id: container id for the new asset.
            results: JSON-serializable object uploaded as the asset content.

        Returns:
            bool: True if the server answered HTTP 200, False otherwise or if
            the request raised (the error is logged).
        """
        logging.debug(
            "Publishing results for recording: {} . Results: {}".format(
                recording_id, results))
        filename = 'tmpfile'

        query = '''
            mutation {
              createAsset(
                input: {
                    containerId: %s,
                    contentType: "application/json",
                    assetType: "object"
                }) {
                id
                uri
              }
            }
            ''' % self._gql_string(recording_id)

        data = {'query': query, 'filename': filename}

        files = {'file': (filename, json.dumps(results))}

        try:
            response = requests.post(self.baseUrl,
                                     data=data,
                                     files=files,
                                     headers=self.headers)
            return response.status_code == HTTPStatus.OK
        except Exception as e:
            logging.error(
                'Failed to create asset for recording: {} due to: {}'.format(
                    recording_id, e))
            return False

    def update_task(self, job_id, task_id, status, output=None):
        """Set a task's status and JSON output.

        Args:
            job_id: id of the job owning the task.
            task_id: id of the task to update.
            status: a member of VALID_TASK_STATUS, inlined as a GraphQL enum.
            output (dict, optional): serialized to JSON and sent as
                ``outputString``. Defaults to ``{}``.

        Returns:
            bool: True on success. False if *status* is invalid or the
            request failed (the error is logged).
        """
        logging.debug("Updating task status to {} for task_id: {}".format(
            status, task_id))
        if status not in VALID_TASK_STATUS:
            return False

        if output is None:
            output = {}

        # Double json.dumps: the inner call serializes the output, the outer
        # one escapes it as a string literal. This also handles backslashes,
        # which replace('"', '\\"') did not.
        query = gql(
            '''
            mutation {
              updateTask(input: {id: %s, jobId: %s, status: %s, outputString: %s}) {
                id
                status
              }
            }
        ''' %
            (self._gql_string(task_id), self._gql_string(job_id), status,
             self._gql_string(json.dumps(output))))
        try:
            self.client.execute(query)
        except Exception as e:
            logging.error(
                'Failed to update task {} status to {} due to: {}'.format(
                    task_id, status, e))
            return False
        # Previously fell through returning None, which callers would read
        # as failure.
        return True
Пример #27
0
    def get_vars(self, loader, path, entities, cache=True):
        ''' Return inventory variables for *entities*, fetched from a GraphQL API.

        Only Group entities are looked up (via ``groupByName``). Host entities
        are accepted but contribute nothing. Per-group results are stored in
        the module-level FOUND dict and reused when *cache* is true.

        Args:
            loader: Ansible loader, stored on self when the plugin config is
                first read.
            path: directory expected to contain ``graphql_plugin.yaml``.
            entities: a Host/Group or a list of them.
            cache (bool): reuse group variables already in FOUND.

        Returns:
            dict: the combined variables of all groups, or ``{}`` if the
            config file fails ``verify_file``.

        Raises:
            AnsibleParserError: for an entity that is neither Host nor Group,
                or for any failure while querying the API.
        '''
        import json  # local: escapes group names into the query text

        if not isinstance(entities, list):
            entities = [entities]

        super(VarsModule, self).get_vars(loader, path, entities)

        # The config file is only parsed while OPTIONS lacks an api_server.
        if 'api_server' not in OPTIONS:
            self.loader = loader
            config_file_path = path + "/graphql_plugin.yaml"
            self.display.v('Load vars plugin configuration file {}'.format(
                config_file_path))

            if self.verify_file(config_file_path):
                self.parse_config_file(config_file_path)
            else:
                return {}

        self.api_server = OPTIONS['api_server']
        # NOTE(review): api_token is stored but never sent with the query
        # below — confirm whether the API expects it (e.g. as a header).
        self.api_token = OPTIONS['api_token']

        data = {}
        for entity in entities:
            if not isinstance(entity, (Host, Group)):
                raise AnsibleParserError(
                    "Supplied entity must be Host or Group, got %s instead" %
                    (type(entity)))
            if not isinstance(entity, Group):
                continue  # Hosts: no host variables are looked up

            key = "Group_%s" % entity.name
            if cache and key in FOUND:
                self.display.v('Load vars from cache')
                new_data = FOUND[key]
            else:
                self.display.v('Load vars from graphql api {}'.format(
                    self.api_server))
                try:

                    # Select your transport with a defined url endpoint
                    transport = AIOHTTPTransport(
                        url="https://{}/graphql".format(self.api_server))

                    # Create a GraphQL client using the defined transport
                    client = Client(transport=transport,
                                    fetch_schema_from_transport=True)

                    # json.dumps yields a correctly quoted and escaped string
                    # literal, so quotes or backslashes in a group name
                    # cannot break (or inject into) the query.
                    query = gql('''
                        query {
                            groupByName(groupName: ''' +
                                json.dumps(entity.name) + ''')
                            {
                                ansible_group_name
                                variables
                            }
                        }
                    ''')

                    # Execute the query on the transport
                    result = client.execute(query)

                    group = result["groupByName"]
                    new_data = group["variables"] if group is not None else {}

                    FOUND[key] = new_data

                except Exception as e:
                    raise AnsibleParserError(e)

            data = combine_vars(data, new_data)

        return data
Пример #28
0
class GraphQLClient(object):
    """GraphQL client for a server that requires a CSRF token.

    ``connect`` GETs the endpoint to obtain the ``csrftoken`` cookie. It then
    builds a ``gql`` client that sends the token back both as a cookie and as
    the ``x-csrftoken`` header. Methods decorated with
    ``Decorators.autoConnectingClient`` are retried after reconnecting.
    """
    CONNECT_TIMEOUT = 15  # [sec] connect timeout of the CSRF bootstrap GET
    RETRY_DELAY = 10  # [sec]
    MAX_RETRIES = 3  # [-]

    class Decorators(object):
        @staticmethod
        def autoConnectingClient(wrappedMethod):
            """Retry *wrappedMethod*, reconnecting the client after each failure.

            The call is attempted up to ``MAX_RETRIES`` times. After each
            failed attempt the client reconnects, sleeping ``RETRY_DELAY``
            seconds if the connection is refused. One final attempt is then
            made outside the loop so that its exception reaches the caller.
            """
            import functools  # local: keeps the decorator self-contained

            @functools.wraps(wrappedMethod)
            def wrapper(obj, *args, **kwargs):
                for _attempt in range(GraphQLClient.MAX_RETRIES):
                    try:
                        return wrappedMethod(obj, *args, **kwargs)
                    except Exception as exc:
                        # Swallowed on purpose: we reconnect and retry below.
                        obj._logger.debug('GraphQL call failed: %s', exc)
                    try:
                        obj._logger.warning(
                            '(Re)connecting to GraphQL service.')
                        obj.reconnect()
                    except (ConnectionRefusedError,
                            requests.exceptions.ConnectionError):
                        # requests reports refused connections with its own
                        # ConnectionError, which is not a ConnectionRefusedError.
                        obj._logger.warning(
                            'Connection refused. Retry in {}s.'.format(
                                GraphQLClient.RETRY_DELAY))
                        time.sleep(GraphQLClient.RETRY_DELAY)
                # Final attempt, outside the retry loop, so the exception is exposed.
                obj.reconnect()
                return wrappedMethod(obj, *args, **kwargs)

            return wrapper

    def __init__(self, serverUrl):
        """Connect to *serverUrl*, a parsed URL object exposing ``geturl()``."""
        self._logger = logging.getLogger(self.__class__.__name__)
        # Kept so __enter__ and reconnect() can re-establish the session
        # without the constructor argument.
        self._url = serverUrl.geturl()
        self._client = None
        self.connect(self._url)

    def __enter__(self):
        self.connect(self._url)
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self._client = None

    def connect(self, url):
        """Fetch a CSRF token from *url* and build the GraphQL client.

        Raises:
            requests.HTTPError: if the bootstrap GET returns an error status.
            KeyError: if the response sets no ``csrftoken`` cookie.
        """
        host = url.split('//')[1].split('/')[0]
        request = requests.get(url,
                               headers={
                                   'Host': str(host),
                                   'Accept': 'text/html',
                               },
                               # Bound only the connect phase; reading the
                               # response may still wait indefinitely.
                               timeout=(self.CONNECT_TIMEOUT, None))
        request.raise_for_status()
        csrf = request.cookies['csrftoken']
        self._client = Client(transport=RequestsHTTPTransport(
            url=url,
            cookies={"csrftoken": csrf},
            headers={'x-csrftoken': csrf}),
                              fetch_schema_from_transport=True)

    def disconnect(self):
        self._client = None

    def reconnect(self):
        """Drop the current client and connect again to the stored URL."""
        self.disconnect()
        self.connect(self._url)

    @Decorators.autoConnectingClient
    def execute_query(self, querytext):
        """Parse *querytext* with ``gql`` and execute it, retrying on failure."""
        query = gql(querytext)
        return self._client.execute(query)
Пример #29
0
class SpeckleClient:
    """
    The `SpeckleClient` is your entry point for interacting with your Speckle Server's GraphQL API.
    You'll need to have access to a server to use it, or you can use our public server `speckle.xyz`.

    To authenticate the client, you'll need to have downloaded the [Speckle Manager](https://speckle.guide/#speckle-manager)
    and added your account.

    ```py
    from specklepy.api.client import SpeckleClient
    from specklepy.api.credentials import get_default_account

    # initialise the client
    client = SpeckleClient(host="speckle.xyz") # or whatever your host is
    # client = SpeckleClient(host="localhost:3000", use_ssl=False) or use local server

    # authenticate the client with a token (account has been added in Speckle Manager)
    account = get_default_account()
    client.authenticate(token=account.token)

    # create a new stream. this returns the stream id
    new_stream_id = client.stream.create(name="a shiny new stream")

    # use that stream id to get the stream from the server
    new_stream = client.stream.get(id=new_stream_id)
    ```
    """

    DEFAULT_HOST = "speckle.xyz"
    USE_SSL = True

    def __init__(self,
                 host: str = DEFAULT_HOST,
                 use_ssl: bool = USE_SSL) -> None:
        """Create an unauthenticated client and check the server is compatible.

        Arguments:
            host {str} -- server host; a leading protocol and trailing slash are stripped
            use_ssl {bool} -- use https/wss rather than http/ws

        Raises:
            SpeckleException -- if the server info query fails or does not return ServerInfo
        """
        ws_protocol = "wss" if use_ssl else "ws"
        http_protocol = "https" if use_ssl else "http"

        # sanitise host input by removing protocol and trailing slash
        host = re.sub(r"((^\w+:|^)\/\/)|(\/$)", "", host)

        self.url = f"{http_protocol}://{host}"
        self.graphql = self.url + "/graphql"
        self.ws_url = f"{ws_protocol}://{host}/graphql"
        self.me = None

        self.httpclient = Client(transport=RequestsHTTPTransport(
            url=self.graphql, verify=True, retries=3))
        self.wsclient = None

        self._init_resources()

        # Check compatibility with the server
        try:
            serverInfo = self.server.get()
            if not isinstance(serverInfo, ServerInfo):
                raise Exception("Couldn't get ServerInfo")
        except Exception as ex:
            raise SpeckleException(
                f"{self.url} is not a compatible Speckle Server", ex) from ex

    def __repr__(self):
        return (
            f"SpeckleClient( server: {self.url}, authenticated: {self.me is not None} )"
        )

    def authenticate(self, token: str) -> None:
        """Authenticate the client using a personal access token
        The token is saved in the client object and a synchronous GraphQL entrypoint is created

        Arguments:
            token {str} -- an api token
        """
        self.me = {"token": token}
        headers = {
            "Authorization": f"Bearer {self.me['token']}",
            "Content-Type": "application/json",
        }
        httptransport = RequestsHTTPTransport(url=self.graphql,
                                              headers=headers,
                                              verify=True,
                                              retries=3)
        wstransport = WebsocketsTransport(
            url=self.ws_url,
            init_payload={"Authorization": f"Bearer {self.me['token']}"},
        )
        self.httpclient = Client(transport=httptransport)
        self.wsclient = Client(transport=wstransport)

        # Rebuild resources so they use the authenticated clients.
        self._init_resources()

    def execute_query(self, query: str) -> Dict:
        """Execute a query on the HTTP client.

        Arguments:
            query -- GraphQL source text, or an already-parsed ``gql`` document

        Returns:
            whatever ``Client.execute`` returns for the query
        """
        if isinstance(query, str):
            # gql's Client.execute expects a parsed document, not raw text.
            from gql import gql
            query = gql(query)
        return self.httpclient.execute(query)

    def _init_resources(self) -> None:
        """(Re)create the per-area resource wrappers bound to the current clients."""
        self.stream = stream.Resource(me=self.me,
                                      basepath=self.url,
                                      client=self.httpclient)
        self.commit = commit.Resource(me=self.me,
                                      basepath=self.url,
                                      client=self.httpclient)
        self.branch = branch.Resource(me=self.me,
                                      basepath=self.url,
                                      client=self.httpclient)
        self.object = object.Resource(me=self.me,
                                      basepath=self.url,
                                      client=self.httpclient)
        self.server = server.Resource(me=self.me,
                                      basepath=self.url,
                                      client=self.httpclient)
        self.user = user.Resource(me=self.me,
                                  basepath=self.url,
                                  client=self.httpclient)
        # Subscriptions go over the websocket client (None until authenticate()).
        self.subscribe = subscriptions.Resource(
            me=self.me,
            basepath=self.ws_url,
            client=self.wsclient,
        )

    def __getattr__(self, name):
        """Build a resource wrapper for ``resources.<name>`` on demand.

        This is only called when normal attribute lookup fails. Dunder names
        raise AttributeError, so protocol probes made by ``hasattr``, ``copy``
        and ``pickle`` behave normally instead of failing with SpeckleException.

        Raises:
            AttributeError -- for dunder names
            SpeckleException -- if ``resources.<name>.Resource`` cannot be built
        """
        if name.startswith("__") and name.endswith("__"):
            raise AttributeError(name)
        try:
            attr = getattr(resources, name)
            return attr.Resource(me=self.me,
                                 basepath=self.url,
                                 client=self.httpclient)
        except Exception as ex:
            raise SpeckleException(
                f"Method {name} is not supported by the SpeckleClient class") from ex
Пример #30
0
class NewRelicGQL(object):
    """Small client for the New Relic NerdGraph (GraphQL) API.

    Wraps a gql ``Client`` bound to the US or EU endpoint and offers helpers
    for cloud linked accounts and their integrations.
    """

    def __init__(self, account_id, api_key, region="us"):
        """Configure the endpoint and the gql client.

        Args:
            account_id: New Relic account id; anything ``int()`` accepts.
            api_key: NerdGraph API key, sent as the ``api-key`` header.
            region: "us" or "eu".

        Raises:
            ValueError: if account_id is not an integer or region is unknown.
        """
        try:
            self.account_id = int(account_id)
        except (TypeError, ValueError) as err:
            # TypeError covers e.g. None, which int() rejects without ValueError.
            raise ValueError("Account ID must be an integer") from err

        self.api_key = api_key

        if region == "us":
            self.url = "https://api.newrelic.com/graphql"
        elif region == "eu":
            self.url = "https://api.eu.newrelic.com/graphql"
        else:
            raise ValueError("Region must be one of 'us' or 'eu'")

        transport = RequestsHTTPTransport(url=self.url, use_json=True)
        transport.headers = {"api-key": self.api_key}

        # Prefer schema-aware client; fall back if schema fetching fails.
        try:
            self.client = Client(transport=transport,
                                 fetch_schema_from_transport=True)
        except Exception:
            self.client = Client(transport=transport,
                                 fetch_schema_from_transport=False)

    def query(self, query, timeout=None, **variable_values):
        """Execute GraphQL source *query* with keyword arguments as variables."""
        return self.client.execute(gql(query),
                                   timeout=timeout,
                                   variable_values=variable_values or None)

    def get_linked_accounts(self):
        """
        return a list of linked accounts for the New Relic account
        """
        res = self.query(
            """
            query ($accountId: Int!) {
              actor {
                account(id: $accountId) {
                  cloud {
                    linkedAccounts {
                      id
                      name
                      createdAt
                      updatedAt
                      authLabel
                      externalId
                    }
                  }
                }
              }
            }
            """,
            accountId=self.account_id,
        )
        return res["actor"]["account"]["cloud"]["linkedAccounts"]

    def get_license_key(self):
        """
        Fetch the license key for the NR Account
        """
        res = self.query(
            """
            query ($accountId: Int!) {
              requestContext {
                apiKey
              }
              actor {
                account(id: $accountId) {
                  licenseKey
                  id
                  name
                }
              }
            }
            """,
            accountId=self.account_id,
        )
        return res["actor"]["account"]["licenseKey"]

    def get_linked_account_by_name(self, account_name):
        """
        return a specific linked account of the New Relic account,
        or None if no linked account has that name
        """
        accounts = self.get_linked_accounts()
        return next((a for a in accounts if a["name"] == account_name), None)

    def link_account(self, role_arn, account_name):
        """
        create a linked account (cloud integrations account)
        in the New Relic account; returns the first linked account reported
        """
        res = self.query(
            """
            mutation ($accountId: Int!, $accounts: CloudLinkCloudAccountsInput!){
              cloudLinkAccount (accountId: $accountId, accounts: $accounts) {
                linkedAccounts {
                  id
                  name
                }
                errors {
                    message
                }
              }
            }
            """,
            accountId=self.account_id,
            accounts={"aws": {
                "arn": role_arn,
                "name": account_name
            }},
        )
        return res["cloudLinkAccount"]["linkedAccounts"][0]

    def unlink_account(self, linked_account_id):
        """
        Unlink a New Relic Cloud integrations account
        """
        return self.query(
            """
            mutation ($accountId: Int!, $accounts: CloudUnlinkCloudAccountsInput!) {
              cloudUnLinkAccount (accountId: $accountId, accounts: $accounts) {
                unlinkedAccounts {
                  id
                  name
                }
                errors {
                  type
                  message
                }
              }
            }
            """,
            accountId=self.account_id,
            accounts={"linkedAccountId": linked_account_id},
        )

    def get_integrations(self, linked_account_id):
        """
        returns the integrations for the linked account
        """
        res = self.query(
            """
            query ($accountId: Int!, $linkedAccountId: Int!) {
              actor {
                account (id: $accountId) {
                  cloud {
                    linkedAccount(id: $linkedAccountId) {
                      integrations {
                        id
                        name
                        createdAt
                        updatedAt
                        service {
                          slug
                          isEnabled
                        }
                      }
                    }
                  }
                }
              }
            }
            """,
            accountId=self.account_id,
            linkedAccountId=int(linked_account_id),
        )
        return res["actor"]["account"]["cloud"]["linkedAccount"][
            "integrations"]

    def get_integration_by_service_slug(self, linked_account_id, service_slug):
        """Return the integration whose service slug matches, or None."""
        integrations = self.get_integrations(linked_account_id)
        return next(
            (i for i in integrations if i["service"]["slug"] == service_slug),
            None)

    def is_integration_enabled(self, linked_account_id, service_slug):
        """Return True if the integration exists and its service is enabled."""
        integration = self.get_integration_by_service_slug(
            linked_account_id, service_slug)
        # bool() so a missing integration yields False rather than None.
        return bool(integration and integration["service"]["isEnabled"])

    def enable_integration(self, linked_account_id, provider_slug,
                           service_slug):
        """
        enable monitoring of a Cloud provider service (integration);
        returns the first integration reported by the mutation
        """
        res = self.query(
            """
            mutation ($accountId: Int!, $integrations: CloudIntegrationsInput!) {
              cloudConfigureIntegration (
                accountId: $accountId,
                integrations: $integrations
              ) {
                integrations {
                  id
                  name
                  service {
                    id
                    name
                  }
                }
                errors {
                  linkedAccountId
                  message
                }
              }
            }
            """,
            accountId=self.account_id,
            integrations={
                provider_slug: {
                    service_slug: [{
                        # int() for consistency with get_integrations
                        "linkedAccountId": int(linked_account_id)
                    }]
                }
            },
        )
        return res["cloudConfigureIntegration"]["integrations"][0]

    def disable_integration(self, linked_account_id, provider_slug,
                            service_slug):
        """
        Disable monitoring of a Cloud provider service (integration)
        """
        return self.query(
            """
            mutation ($accountId: Int!, $integrations: CloudIntegrationsInput!) {
              cloudDisableIntegration (
                accountId: $accountId,
                integrations: $integrations
              ) {
                disabledIntegrations {
                  id
                  accountId
                  name
                }
                errors {
                  type
                  message
                }
              }
            }
            """,
            accountId=self.account_id,
            integrations={
                provider_slug: {
                    service_slug: [{
                        # int() for consistency with get_integrations
                        "linkedAccountId": int(linked_account_id)
                    }]
                }
            },
        )
class GraphqlMongodbController():
    def __init__(self):
        self.mongodbConnection = False
        self.graphqlConnection = False
        self.schema = object
        self.gqlclient = object
        self.mongodbclient = object

    def setEnvironmentVariable(self, environmentVariable):
        try:
            if os.getenv(environmentVariable) is not None:
                setattr(self, environmentVariable,
                        os.getenv(environmentVariable))
            else:
                print(f"{environmentVariable} is not set")
                return False
            return True
        except:
            return False

    def setMongodbClient(self):
        try:
            self.mongodbclient = connect("production_securethebox",
                                         host="mongodb+srv://" +
                                         os.environ["MONGODB_USER"] + ":" +
                                         os.environ["MONGODB_PASSWORD"] + "@" +
                                         os.environ["MONGODB_CLUSTER"],
                                         alias="default")
            self.mongodbConnection = True
            return True
        except:
            return False

    def setSchema(self):
        """Attach the module-level ``schema`` object to this controller.

        ``schema`` is resolved from the enclosing module (presumably the
        graphene schema — confirm).

        Returns:
            True on success, False if the name cannot be resolved.
        """
        try:
            self.schema = schema
        except Exception:
            return False
        return True

    def setGraphqlClient(self):
        """Create the GraphQL client used by the query methods.

        Under pytest, a ``TestClient`` executes against ``self.schema``
        in-process. Otherwise, a gql ``Client`` posts to
        http://localhost:5000/graphql, with ``self.schema`` passed as its schema.

        Returns:
            True on success (and sets ``graphqlConnection``), otherwise False.
        """
        try:
            if "pytest" in sys.modules:
                self.gqlclient = TestClient(self.schema)
            else:
                self.gqlclient = Client(transport=RequestsHTTPTransport(
                    url='http://localhost:5000/graphql'),
                                        schema=self.schema)
        except Exception:
            return False
        self.graphqlConnection = True
        return True

    def addCategory(self, payload):
        """Return the id of the Category for ``payload["value"]``, creating it if absent.

        Args:
            payload: mapping with "value", "label" and "color" keys.

        Returns:
            ``(True, id)`` on success; ``False`` if a key is missing, the
            lookup fails, or saving raises.
        """
        try:
            value = payload["value"]
            label = payload["label"]
            color = payload["color"]
            lookup = self.getCategoryByValue(value)
            if lookup is False:
                # The lookup reports failure as a bare False, not a tuple.
                return False
            _, existing = lookup
            if existing:
                return True, existing[0]["node"]["id"]
            model = Category(value=value, label=label, color=color)
            model.save()
            return True, model.id
        except Exception:
            return False

    def addService(self, payload):
        """Return the id of the Service for ``payload["value"]``, creating it if absent.

        Args:
            payload: mapping with a "value" key.

        Returns:
            ``(True, id)`` on success; ``False`` if the key is missing, the
            lookup fails, or saving raises.
        """
        try:
            value = payload["value"]
            lookup = self.getServiceByValue(value)
            if lookup is False:
                # The lookup reports failure as a bare False, not a tuple.
                return False
            _, existing = lookup
            if existing:
                return True, existing[0]["node"]["id"]
            model = Service(value=value)
            model.save()
            return True, model.id
        except Exception:
            return False

    def getAllCategories(self):
        """Fetch every Category.

        Returns:
            ``(True, edges)`` where edges is the ``allCategories`` edge list,
            or ``False`` if the query fails.
        """
        try:
            query = gql('''
            {
                allCategories{
                    edges {
                        node {
                            id,
                            value,
                            label,
                            color
                        }
                    }
                }
            }
            ''')
            if "pytest" in sys.modules:
                # The pytest TestClient returns the full {"data": ...} response.
                output = self.gqlclient.execute(query)["data"]
            else:
                # gql's Client.execute returns the data payload itself.
                output = self.gqlclient.execute(query)
            return True, output["allCategories"]["edges"]
        except Exception:
            return False

    def getCategoryByValue(self, value):
        """Look up Categories whose value equals *value*.

        Returns:
            ``(True, edges)`` with the ``allCategories`` edge list (empty when
            none match), or ``False`` if the query fails.
        """
        try:
            query = gql('''
            query ($value: String!) {
                allCategories (value: $value) {
                    edges {
                        node {
                            id,
                            value,
                            label
                        }
                    }
                }
            }
            ''')
            variables = {"value": value}
            if "pytest" in sys.modules:
                # The pytest TestClient returns the full {"data": ...} response.
                output = self.gqlclient.execute(
                    query, variables=variables)["data"]
            else:
                # gql expects variable_values as a dict and returns the data
                # payload itself (no "data" envelope).
                output = self.gqlclient.execute(query,
                                                variable_values=variables)
            return True, output["allCategories"]["edges"]
        except Exception:
            return False

    def getServiceByValue(self, value):
        """Look up Services whose value equals *value*.

        Returns:
            ``(True, edges)`` with the ``allServices`` edge list (empty when
            none match), or ``False`` if the query fails.
        """
        try:
            query = gql('''
            query ($value: String!) {
                allServices (value: $value) {
                    edges {
                        node {
                            id,
                            value,
                            label
                        }
                    }
                }
            }
            ''')
            variables = {"value": value}
            if "pytest" in sys.modules:
                # The pytest TestClient returns the full {"data": ...} response.
                output = self.gqlclient.execute(
                    query, variables=variables)["data"]
            else:
                # gql expects variable_values as a dict and returns the data
                # payload itself (no "data" envelope).
                output = self.gqlclient.execute(query,
                                                variable_values=variables)
            return True, output["allServices"]["edges"]
        except Exception:
            return False

    def getApplicationByValue(self, value):
        """Look up Applications whose value equals *value*.

        Returns:
            ``(True, edges)`` with the ``allApplications`` edge list (empty
            when none match), or ``False`` if the query fails.
        """
        try:
            query = gql('''
            query ($value: String!) {
                allApplications (value: $value) {
                    edges {
                        node {
                            id,
                            value,
                            label
                        }
                    }
                }
            }
            ''')
            variables = {"value": value}
            if "pytest" in sys.modules:
                # The pytest TestClient returns the full {"data": ...} response.
                output = self.gqlclient.execute(
                    query, variables=variables)["data"]
            else:
                # gql expects variable_values as a dict and returns the data
                # payload itself (no "data" envelope).
                output = self.gqlclient.execute(query,
                                                variable_values=variables)
            return True, output["allApplications"]["edges"]
        except Exception:
            return False

    def getCourseBySlug(self, slug):
        """Look up Courses by *slug*, including each course's category.

        Returns:
            ``(True, output)`` with the client's raw result, or ``False`` if
            the query fails. NOTE(review): the result shape differs by client.
            Under pytest it is the full response with a "data" key; with gql
            it is the data payload itself.
        """
        try:
            query = gql('''
            query ($slug: String!) {
                allCourses (slug: $slug) {
                    edges {
                        node {
                            id
                            activeStep
                            description
                            length
                            slug
                            title
                            totalSteps
                            category {
                                label
                                value
                            }
                        }
                    }
                }
            }
            ''')
            variables = {"slug": slug}
            if "pytest" in sys.modules:
                output = self.gqlclient.execute(query, variables=variables)
            else:
                # gql expects variable_values as a dict, not a JSON string.
                output = self.gqlclient.execute(query,
                                                variable_values=variables)
            return True, output
        except Exception:
            return False

    def getAllCourses(self):
        """Fetch every Course and print the raw result.

        Returns:
            ``(True, output)`` with the client's raw result, or ``False`` if
            the query fails.
        """
        try:
            query = gql('''
            {
                allCourses{
                    edges {
                        node {
                            id
                            title
                            activeStep
                            description
                            length
                            slug
                            totalSteps
                        }
                    }
                }
            }
            ''')
            output = self.gqlclient.execute(query)
            print(output)
            return True, output
        except Exception:
            return False

    def mutate(self):
        """Build a sample ``createCourse`` mutation and its variables.

        NOTE(review): the mutation is never sent. Neither ``query`` nor
        ``variables`` is passed to ``self.gqlclient``, so this method has no
        effect. Confirm whether it should call
        ``self.gqlclient.execute(query, variable_values=variables)``.

        Returns:
            None on success, False if parsing the mutation raises.
        """
        try:
            query = gql('''
            mutation CreateCourse($title: String!, $activeStep: Int!, $description: String!, $length: Int!, $slug: String!, $totalSteps: Int!) {
                createCourse(title: $title, activeStep: $activeStep, description: $description, length: $length, slug: $slug, totalSteps: $totalSteps) {
                    course {
                        title
                        activeStep
                        description
                        length
                        slug
                        totalSteps
                    }
                }
            }
            ''')
            variables = {
                "title": "Blue Team - Security Engineer IC1",
                "activeStep": 0,
                "description": "This challenge is hard!",
                "length": 50,
                "slug": "this-is-slug",
                "totalSteps": 0
            }
        except Exception:
            return False

    def addCourse(self, course_payload):
        """Create a Course from *course_payload* unless its slug already exists.

        Args:
            course_payload: mapping with "title", "activeStep", "description",
                "length", "slug", "totalSteps" and "category" keys. A missing
                key raises KeyError before any query runs.

        Returns:
            ``(True, "good")`` on success or when the course already exists;
            ``False`` if the lookup or save fails.
        """
        title = course_payload["title"]
        activeStep = course_payload["activeStep"]
        description = course_payload["description"]
        length = course_payload["length"]
        slug = course_payload["slug"]
        totalSteps = course_payload["totalSteps"]
        # NOTE(review): read but overwritten below; courses are always linked
        # to a newly saved hard-coded "Red team" category. Confirm intent.
        category = course_payload["category"]
        try:
            _, output = self.getCourseBySlug(slug)
            testing = "pytest" in sys.modules
            if testing:
                # The pytest TestClient returns the full {"data": ...} response.
                edges = output["data"]["allCourses"]["edges"]
            else:
                # gql returns the data payload itself.
                edges = output["allCourses"]["edges"]
            # Only create when no course has this slug. The non-pytest path
            # previously skipped this check and inserted duplicates.
            if len(edges) == 0:
                if testing:
                    xcategory = Category(id=0,
                                         value="red_team",
                                         label="Red team",
                                         color="#2196f3")
                else:
                    xcategory = Category(label="Red team", value="red_team")
                category = xcategory.save()
                print("CATEGORY:", category)
                course = Course(title=title,
                                activeStep=activeStep,
                                description=description,
                                length=length,
                                slug=slug,
                                totalSteps=totalSteps,
                                category=category)
                course.save()
            return True, "good"
        except Exception:
            return False