예제 #1
0
def get_authorization_code(client, auth_url, redirect_url, scope, state,
                           challenge, port):
    """
    Drive the browser step of the OAuth authorization-code flow with PKCE.

    Builds the authorization request with ``client`` (using an S256 code
    challenge), opens it in the user's browser, and serves a single HTTP
    request on ``port`` to receive the redirect.

    The captured request path is read from the module-level
    ``global_request_path`` (presumably set by ``SingleRequestHandler`` --
    confirm). It is turned back into a full URL and parsed by
    ``client.parse_request_uri_response``.

    :return: the authorization response parsed by ``client``.

    A missing callback path, or an ``OAuth2Error`` while parsing, is reported
    through ``error_and_quit``.
    """
    # Only the URI is used; the generated headers and body are not needed.
    auth_req_uri, _unused_headers, _unused_body = client.prepare_authorization_request(
        authorization_url=auth_url,
        redirect_url=redirect_url,
        scope=scope,
        state=state,
        code_challenge=challenge,
        code_challenge_method="S256")
    click.echo("Opening {uri}".format(uri=auth_req_uri))

    listen_address = ("", port)
    with HTTPServer(listen_address, SingleRequestHandler) as httpd:
        webbrowser.open_new(auth_req_uri)
        listening_msg = "Listening for OAuth authorization callback at {uri}".format(
            uri=redirect_url)
        click.echo(listening_msg)
        # Block until exactly one request (the redirect) has been handled.
        httpd.handle_request()

    if not global_request_path:
        error_and_quit(
            "No path parameters were returned to the callback at {uri}".format(
                uri=redirect_url))

    # Kludge: the parsing library expects https callbacks, so the callback is
    # rebuilt as an https URL on localhost even though the listener is plain
    # HTTPServer. Setting the listener up with https would be cleaner.
    full_redirect_url = "https://localhost:{port}/{path}".format(
        port=port, path=global_request_path)
    try:
        authorization_code_response = client.parse_request_uri_response(
            full_redirect_url, state=state)
    except OAuth2Error as err:
        error_and_quit(
            "OAuth Token Request error {error}".format(error=err.description))
    return authorization_code_response
예제 #2
0
def check_and_refresh_access_token(hostname, access_token, refresh_token):
    """
    Return a usable access token, refreshing it if it has expired.

    The access token's ``exp`` claim is read without verifying the signature.
    If the token is still valid, it is returned unchanged. Otherwise the
    refresh token is exchanged for a new token pair.

    :param hostname: host passed to ``send_refresh_token_request``.
    :param access_token: encoded JWT access token.
    :param refresh_token: refresh token, or a falsy value if none is available.
    :return: ``(access_token, refresh_token, refreshed)``, where ``refreshed``
        is True when a new token pair was fetched.

    A JWT decoding failure, or an expired token with no refresh token, is
    reported through ``error_and_quit``.
    """
    now = datetime.now(tz=UTC)
    # If we can't determine an expiration time (e.g. the token carries no
    # ``exp`` claim), the token is treated as already expired.
    expiration_time = now
    try:
        # This token has already been verified and we are just parsing it.
        # If it has been tampered with, it will be rejected on the server side.
        # This avoids having to fetch the public key from the issuer and perform
        # an unnecessary signature verification.
        decoded = jwt.decode(access_token, options={"verify_signature": False})
    except PyJWTError as err:
        error_and_quit(err)
    else:
        # Use .get(): a missing claim used to escape as an uncaught KeyError.
        exp = decoded.get('exp')
        if exp is not None:
            expiration_time = datetime.fromtimestamp(exp, tz=UTC)

    if expiration_time > now:
        # The access token is fine. Just return it.
        return access_token, refresh_token, False

    if not refresh_token:
        error_and_quit(
            "OAuth access token expired on {expiration_time}.".format(
                expiration_time=expiration_time))

    # Try to refresh using the refresh token
    click.echo(
        "Attempting to refresh OAuth access token that expired on {expiration_time}"
        .format(expiration_time=expiration_time))
    oauth_response = send_refresh_token_request(hostname, refresh_token)
    fresh_access_token, fresh_refresh_token = get_tokens_from_response(
        oauth_response)
    return fresh_access_token, fresh_refresh_token, True
예제 #3
0
 def decorator(*args, **kwargs):
     # Wrapper that only runs ``function`` when the on-disk CLI configuration
     # (loaded via DatabricksConfig.fetch_from_fs) is valid. Otherwise it
     # quits with a hint to run the ``configure`` sub-command of the current
     # program (sys.argv[0]). No docstring here: the enclosing decorator
     # (not visible) may copy or expose __doc__.
     cli_config = DatabricksConfig.fetch_from_fs()
     if not cli_config.is_valid:
         hint = ('You haven\'t configured the CLI yet! '
                 'Please configure by entering `{} configure`').format(sys.argv[0])
         error_and_quit(hint)
     return function(*args, **kwargs)
예제 #4
0
def _validate_pipeline_id(pipeline_id):
    """
    Checks that the pipeline id is non-empty and contains only permitted
    characters.

    The permitted set is PIPELINE_ID_PERMITTED_CHARACTERS; per the error
    message, that is hyphen (-), underscore (_) and ASCII alphanumerics.
    ``None`` is treated the same as an empty id. Failures are reported
    through ``error_and_quit``.
    """
    # Check for None first: ``len(None)`` would raise a TypeError instead of
    # producing the intended error message.
    if pipeline_id is None or len(pipeline_id) == 0:
        error_and_quit(u'Empty pipeline id provided')
    if not set(pipeline_id) <= PIPELINE_ID_PERMITTED_CHARACTERS:
        message = u'Pipeline id {} has invalid character(s)\n'.format(pipeline_id)
        message += u'Valid characters are: _ - a-z A-Z 0-9'
        error_and_quit(message)
예제 #5
0
def _validate_pipeline_id(pipeline_id):
    """
    Validate a pipeline ID, quitting with an error message if it is unusable.

    The ID is rejected when it is None or empty, or when it contains any
    character outside PIPELINE_ID_PERMITTED_CHARACTERS.
    """
    is_empty = pipeline_id is None or len(pipeline_id) == 0
    if is_empty:
        error_and_quit(u'Empty pipeline ID provided')
    only_permitted = set(pipeline_id) <= PIPELINE_ID_PERMITTED_CHARACTERS
    if not only_permitted:
        error_and_quit(
            u'Pipeline ID {} has invalid character(s)\n'.format(pipeline_id)
            + u'Valid characters are: _ - a-z A-Z 0-9')
예제 #6
0
 def validate(self):
     """
     Verify that this object denotes a well-formed DBFS path.

     A double-slash form ("dbfs://...") is rejected first. The path must then
     be absolute according to ``is_absolute_path`` (the error message asks
     for a "dbfs:/" prefix). Violations are reported through
     ``error_and_quit``.
     """
     if self.absolute_path.startswith('dbfs://'):
         message = ('The path {} cannot start with dbfs://. '
                    'It must start with dbfs:/').format(repr(self))
         error_and_quit(message)
     if self.is_absolute_path:
         return
     error_and_quit('The path {} must start with "{}"'.format(
         repr(self), repr(DbfsPath('dbfs:/'))))
예제 #7
0
 def cp(self, recursive, overwrite, src, dst, headers=None):
     """
     Copy a file or directory between the local filesystem and DBFS.

     Exactly one of ``src`` and ``dst`` must be a DBFS path, as judged by
     ``DbfsPath.is_valid``. Copies where both paths are local, or both are
     DBFS paths, are rejected through ``error_and_quit``.

     :param recursive: when true, directories are copied with their contents;
         a plain file is still copied exactly once on its own.
     :param overwrite: forwarded to the underlying copy helpers.
     :param src: source path, local or DBFS.
     :param dst: destination path, local or DBFS.
     :param headers: optional HTTP headers forwarded to every DBFS API call.
     """
     # Classify both paths once instead of re-checking in every branch.
     src_is_dbfs = DbfsPath.is_valid(src)
     dst_is_dbfs = DbfsPath.is_valid(dst)
     if not src_is_dbfs and dst_is_dbfs:
         # Copy to DBFS in this case
         if not os.path.exists(src):
             error_and_quit('The local file {} does not exist.'.format(src))
         if not recursive:
             if os.path.isdir(src):
                 error_and_quit(
                     ('The local file {} is a directory. You must provide --recursive')
                     .format(src))
             self._copy_to_dbfs_non_recursive(src, DbfsPath(dst), overwrite, headers=headers)
         elif os.path.isdir(src):
             self._copy_to_dbfs_recursive(src, DbfsPath(dst), overwrite, headers=headers)
         else:
             # --recursive on a plain file falls back to a single-file copy.
             self._copy_to_dbfs_non_recursive(src, DbfsPath(dst), overwrite, headers=headers)
     elif src_is_dbfs and not dst_is_dbfs:
         # Copy from DBFS in this case
         if not recursive:
             self._copy_from_dbfs_non_recursive(DbfsPath(src), dst, overwrite, headers=headers)
         else:
             dbfs_path_src = DbfsPath(src)
             if self.get_status(dbfs_path_src, headers=headers).is_dir:
                 self._copy_from_dbfs_recursive(dbfs_path_src, dst, overwrite, headers=headers)
             else:
                 # A single DBFS file is copied once; the recursive copy must
                 # not also run on it (mirrors the local -> DBFS branch).
                 self._copy_from_dbfs_non_recursive(dbfs_path_src, dst, overwrite,
                                                    headers=headers)
     elif not src_is_dbfs and not dst_is_dbfs:
         error_and_quit('Both paths provided are from your local filesystem. '
                        'To use this utility, one of the src or dst must be prefixed '
                        'with dbfs:/')
     else:
         # Both paths are DBFS paths.
         error_and_quit('Both paths provided are from the DBFS filesystem. '
                        'To copy between the DBFS filesystem, you currently must copy the '
                        'file from DBFS to your local filesystem and then back.')
예제 #8
0
def ls_cli(l, absolute, dbfs_path): #  NOQA
    """
    List files in DBFS.
    """
    # ``dbfs_path`` is the sequence of positional path arguments (presumably a
    # click tuple). At most one is allowed; with none, the DBFS root is listed.
    # (The docstring above is kept verbatim since it may serve as CLI help.)
    num_paths = len(dbfs_path)
    if num_paths == 1:
        dbfs_path = dbfs_path[0]
    elif num_paths == 0:
        dbfs_path = DbfsPath('dbfs:/')
    else:
        error_and_quit('ls can take a maximum of one path.')
    rows = [entry.to_row(is_long_form=l, is_absolute=absolute)
            for entry in list_files(dbfs_path)]
    click.echo(tabulate(rows, tablefmt='plain'))
예제 #9
0
 def get_file(self, dbfs_path, dst_path, overwrite, headers=None):
     """
     Download the DBFS file ``dbfs_path`` to the local path ``dst_path``.

     The file is fetched chunk by chunk with ``self.read``, starting at offset
     0 and advancing by each response's ``bytes_read`` until ``file_size``
     bytes have been consumed. Each chunk's base64 ``data`` is decoded and
     appended to the destination file.

     :raises LocalFileExistsException: if ``dst_path`` exists and
         ``overwrite`` is false.

     A directory source is reported through ``error_and_quit``.
     """
     if not overwrite and os.path.exists(dst_path):
         raise LocalFileExistsException('{} exists already.'.format(dst_path))
     status = self.get_status(dbfs_path, headers=headers)
     if status.is_dir:
         error_and_quit(('The dbfs file {} is a directory.').format(repr(dbfs_path)))
     total_size = status.file_size
     with open(dst_path, 'wb') as out:
         position = 0
         while position < total_size:
             chunk = self.read(dbfs_path, position, headers=headers)
             position += chunk['bytes_read']
             out.write(b64decode(chunk['data']))
예제 #10
0
def _read_spec(src):
    """
    Reads the spec at ``src`` as JSON.

    Files whose extension is ``.json`` (case-insensitive), or that have no
    extension at all, are parsed as JSON. Any other extension raises
    RuntimeError. Malformed JSON is reported through ``error_and_quit``.
    """
    extension = os.path.splitext(src)[1]
    # An empty extension is accepted as JSON, as the docstring has always
    # promised (and as _read_settings does); previously it raised.
    if extension.lower() == '.json' or not extension:
        try:
            with open(src, 'r') as f:
                data = f.read()
            return json.loads(data)
        except json_parse_exception as e:
            error_and_quit("Invalid JSON provided in spec\n{}".format(e))
    else:
        raise RuntimeError('The provided file extension for the spec is not supported')
예제 #11
0
def get_file(dbfs_path, dst_path, overwrite):
    """
    Download the DBFS file ``dbfs_path`` into the local file ``dst_path``.

    The file is read in chunks of up to BUFFER_SIZE_BYTES through the client
    returned by ``get_dbfs_client()``. Each chunk's base64 ``data`` is decoded
    and appended to the destination file.

    :raises LocalFileExistsException: if ``dst_path`` exists and
        ``overwrite`` is false.

    A directory source is reported through ``error_and_quit``.
    """
    if not overwrite and os.path.exists(dst_path):
        raise LocalFileExistsException()
    client = get_dbfs_client()
    info = get_status(dbfs_path)
    if info.is_dir:
        error_and_quit(('The dbfs file {} is a directory.').format(repr(dbfs_path)))
    size = info.file_size
    with open(dst_path, 'wb') as handle:
        cursor = 0
        while cursor < size:
            reply = client.read(dbfs_path.absolute_path, cursor, BUFFER_SIZE_BYTES)
            cursor += reply['bytes_read']
            handle.write(b64decode(reply['data']))
예제 #12
0
def cp_cli(api_client, recursive, overwrite, src, dst):
    """
    Copy files to and from DBFS.

    Note that this function will fail if the src and dst are both on the local filesystem
    or if they are both DBFS paths.

    For non-recursive copies, if the dst is a directory, the file will be placed inside the
    directory. For example ``dbfs cp dbfs:/apple.txt .`` will create a file at `./apple.txt`.

    For recursive copies, files inside of the src directory will be copied inside the dst directory
    with the same name. If the dst path does not exist, a directory will be created. For example
    ``dbfs cp -r dbfs:/foo foo`` will create a directory foo and place the files ``dbfs:/foo/a`` at
    ``foo/a``. If ``foo/a`` already exists, the file will not be overriden unless the --overwrite
    flag is provided -- however, dbfs cp --recursive will continue to try and copy other files.
    """
    # The docstring above is kept verbatim because it may serve as CLI help.
    # Classify both paths once; exactly one of them must be a DBFS path.
    dbfs_api = DbfsApi(api_client)
    src_is_dbfs = DbfsPath.is_valid(src)
    dst_is_dbfs = DbfsPath.is_valid(dst)
    if not src_is_dbfs and dst_is_dbfs:
        # Copy to DBFS in this case
        if not os.path.exists(src):
            error_and_quit('The local file {} does not exist.'.format(src))
        if not recursive:
            if os.path.isdir(src):
                error_and_quit((
                    'The local file {} is a directory. You must provide --recursive'
                ).format(src))
            copy_to_dbfs_non_recursive(dbfs_api, src, DbfsPath(dst), overwrite)
        elif os.path.isdir(src):
            copy_to_dbfs_recursive(dbfs_api, src, DbfsPath(dst), overwrite)
        else:
            # --recursive on a plain file falls back to a single-file copy.
            copy_to_dbfs_non_recursive(dbfs_api, src, DbfsPath(dst), overwrite)
    elif src_is_dbfs and not dst_is_dbfs:
        # Copy from DBFS in this case
        if not recursive:
            copy_from_dbfs_non_recursive(dbfs_api, DbfsPath(src), dst,
                                         overwrite)
        else:
            dbfs_path_src = DbfsPath(src)
            if dbfs_api.get_status(dbfs_path_src).is_dir:
                copy_from_dbfs_recursive(dbfs_api, dbfs_path_src, dst, overwrite)
            else:
                # A single DBFS file is copied once; the recursive copy must
                # not also run on it (mirrors the local -> DBFS branch).
                copy_from_dbfs_non_recursive(dbfs_api, dbfs_path_src, dst,
                                             overwrite)
    elif not src_is_dbfs and not dst_is_dbfs:
        error_and_quit(
            'Both paths provided are from your local filesystem. '
            'To use this utility, one of the src or dst must be prefixed '
            'with dbfs:/')
    else:
        # Both paths are DBFS paths.
        error_and_quit(
            'Both paths provided are from the DBFS filesystem. '
            'To copy between the DBFS filesystem, you currently must copy the '
            'file from DBFS to your local filesystem and then back.')
예제 #13
0
def _read_settings(src):
    """
    Load the settings file at ``src`` as JSON.

    Files ending in ``.json`` (case-insensitive), or with no extension at all,
    are parsed as JSON. Any other extension raises ValueError. Malformed JSON
    is reported through ``error_and_quit``.
    """
    extension = os.path.splitext(src)[1]
    if extension and extension.lower() != '.json':
        raise ValueError(
            'The provided file extension for the settings is not supported. '
            'Only JSON files are supported.')
    try:
        with open(src, 'r') as handle:
            return json.loads(handle.read())
    except json_parse_exception as e:
        error_and_quit("Invalid JSON provided in settings\n{}.".format(e))
예제 #14
0
def get_file_contents(dbfs_service: DbfsService,
                      dbfs_path: Text,
                      headers=None):
    """
    Read a DBFS file and return its contents decoded as UTF-8 text.

    ``dbfs_path`` is prefixed with ``dbfs:`` to form the absolute path, so it
    presumably starts with ``/`` (confirm at call sites). The file is read in
    chunks of BUFFER_SIZE_BYTES via ``dbfs_service.read``; the decoded base64
    chunks are joined before UTF-8 decoding.

    :param dbfs_service: service used for ``get_status`` and ``read`` calls.
    :param dbfs_path: path of the file inside DBFS, without the ``dbfs:`` prefix.
    :param headers: optional HTTP headers forwarded to every service call.
    :return: the full file contents as ``str``.

    A directory path is reported through ``error_and_quit``.
    """
    abs_path = f"dbfs:{dbfs_path}"
    # Named ``status`` rather than ``json`` to avoid shadowing the json module.
    status = dbfs_service.get_status(abs_path, headers=headers)
    file_info = FileInfo.from_json(status)
    if file_info.is_dir:
        error_and_quit('The dbfs file {} is a directory.'.format(
            repr(abs_path)))
    length = file_info.file_size
    offset = 0
    chunks = []
    while offset < length:
        response = dbfs_service.read(abs_path,
                                     offset,
                                     BUFFER_SIZE_BYTES,
                                     headers=headers)
        offset += response['bytes_read']
        chunks.append(b64decode(response['data']))
    # Decode once at the end: decoding chunk by chunk raises UnicodeDecodeError
    # when a multi-byte UTF-8 sequence straddles a chunk boundary.
    return b"".join(chunks).decode("utf-8")
예제 #15
0
def get_tokens(hostname, scope=None):
    """
    Obtain OAuth tokens for ``hostname`` with the authorization-code + PKCE flow.

    Fetches the identity provider's well-known configuration. Runs the
    browser authorization step against the ``/oidc`` redirector on
    ``hostname``, then exchanges the returned code (with the PKCE verifier)
    at the configured token endpoint.

    :param hostname: workspace URL; it is concatenated directly with
        "oidc/v1/authorize", so it presumably ends with "/" (confirm).
    :param scope: optional OAuth scope passed to the authorization request.
    :return: whatever ``get_tokens_from_response`` extracts from the token
        response.

    An ``OAuth2Error`` during authorization is reported through
    ``error_and_quit``.
    """
    oauth_config = fetch_well_known_config(get_idp_url(hostname))
    # Use the /oidc redirector on the hostname instead of
    # oauth_config["authorization_endpoint"]; it may inject additional parameters.
    auth_url = "{}oidc/v1/authorize".format(hostname)
    state = token_urlsafe(16)
    verifier, challenge = get_challenge()
    client = get_client()
    redirect_url = get_redirect_url()
    try:
        auth_response = get_authorization_code(
            client, auth_url, redirect_url, scope, state, challenge, REDIRECT_PORT)
    except OAuth2Error as err:
        error_and_quit(
            "OAuth Authorization Error: {error}".format(error=err.description))

    token_request_url = oauth_config["token_endpoint"]
    code = auth_response['code']
    oauth_response = send_auth_code_token_request(
        client, token_request_url, redirect_url, code, verifier)
    return get_tokens_from_response(oauth_response)
예제 #16
0
 def cp(self, recursive, overwrite, src, dst, headers=None):
     """
     Copy files between the local filesystem and DBFS, or within DBFS.

     Local -> DBFS and DBFS -> local copies are delegated to the private copy
     helpers. A DBFS -> DBFS copy is staged through a local temporary
     directory: the source is downloaded into it and then uploaded. A copy
     where both paths are local is rejected through ``error_and_quit``.

     :param recursive: when true, directories are copied with their contents;
         a plain file is still copied exactly once on its own.
     :param overwrite: forwarded to the copy helpers. The staging download in
         a DBFS -> DBFS copy always overwrites inside the temporary directory.
     :param src: source path, local or DBFS.
     :param dst: destination path, local or DBFS.
     :param headers: optional HTTP headers forwarded to every DBFS API call,
         including both legs of a DBFS -> DBFS copy.
     """
     # Classify both paths once instead of re-checking in every branch.
     src_is_dbfs = DbfsPath.is_valid(src)
     dst_is_dbfs = DbfsPath.is_valid(dst)
     if not src_is_dbfs and dst_is_dbfs:
         # Copy to DBFS in this case
         if not os.path.exists(src):
             error_and_quit('The local file {} does not exist.'.format(src))
         if not recursive:
             if os.path.isdir(src):
                 error_and_quit((
                     'The local file {} is a directory. You must provide --recursive'
                 ).format(src))
             self._copy_to_dbfs_non_recursive(src,
                                              DbfsPath(dst),
                                              overwrite,
                                              headers=headers)
         elif os.path.isdir(src):
             self._copy_to_dbfs_recursive(src,
                                          DbfsPath(dst),
                                          overwrite,
                                          headers=headers)
         else:
             # --recursive on a plain file falls back to a single-file copy.
             self._copy_to_dbfs_non_recursive(src,
                                              DbfsPath(dst),
                                              overwrite,
                                              headers=headers)
     elif src_is_dbfs and not dst_is_dbfs:
         # Copy from DBFS in this case
         if not recursive:
             self._copy_from_dbfs_non_recursive(DbfsPath(src),
                                                dst,
                                                overwrite,
                                                headers=headers)
         else:
             dbfs_path_src = DbfsPath(src)
             if self.get_status(dbfs_path_src, headers=headers).is_dir:
                 self._copy_from_dbfs_recursive(dbfs_path_src,
                                                dst,
                                                overwrite,
                                                headers=headers)
             else:
                 # A single DBFS file is copied once; the recursive copy must
                 # not also run on it (mirrors the local -> DBFS branch).
                 self._copy_from_dbfs_non_recursive(dbfs_path_src,
                                                    dst,
                                                    overwrite,
                                                    headers=headers)
     elif not src_is_dbfs and not dst_is_dbfs:
         error_and_quit(
             'Both paths provided are from your local filesystem. '
             'To use this utility, one of the src or dst must be prefixed '
             'with dbfs:/')
     else:
         # Both paths are DBFS paths: stage the copy through a local temp dir.
         with TempDir() as temp_dir:
             # Always copy to <temp_dir>/temp since this will work no matter if it's a
             # recursive or a non-recursive copy.
             temp_path = temp_dir.path('temp')
             # Forward headers so both legs authenticate like every other branch.
             self.cp(recursive, True, src, temp_path, headers=headers)
             self.cp(recursive, overwrite, temp_path, dst, headers=headers)
예제 #17
0
def _verify_and_translate_options(string_value, binary_file):
    """
    Translates options into actual parameters for API call.

    The secret value comes from exactly one source:

    * ``string_value`` -- returned as-is.
    * ``binary_file`` -- the file is read in binary mode, and its contents are
      returned base64-encoded as a ``str``.
    * neither -- an editor is opened via ``click.edit``. The text above the
      DASH_MARKER line is used, with trailing newlines stripped.

    Supplying both options is an error, reported via ``error_and_quit``. This
    holds even when one of them is an empty string.

    Return tuple with two values representing (string_value, bytes_value);
    the element for the unused source is None.
    """
    # Compare against None rather than relying on truthiness. Otherwise an
    # empty string value combined with a binary file would silently ignore
    # the file instead of being rejected.
    if string_value is not None and binary_file is not None:
        error_and_quit("At most one of {} should be provided.".format(
            ['string-value', 'binary-file']))

    elif string_value is None and binary_file is None:
        prompt = '# Do not edit the above line. Everything below it will be ignored.\n' + \
            '# Please input your secret value above the line. Text will be stored in\n' + \
            '# UTF-8 (MB4) form and any trailing new line will be stripped.\n' + \
            '# Exit without saving will abort writing secret.'

        # underlying edit function made sure using a temporary file for editing
        content = click.edit('\n\n' + DASH_MARKER + prompt)
        # return None means editor is closed without changes
        if content is None:
            error_and_quit(
                'No changes made, write secret aborted.'
                ' Please follow the instruction to input secret value.')

        elif DASH_MARKER not in content:
            error_and_quit(
                'Please DO NOT edit the line with dashes. Write secret aborted.'
            )

        # Keep only what the user typed above the marker line.
        return content.split(DASH_MARKER, 1)[0].rstrip('\n'), None

    elif string_value is not None:
        return string_value, None

    elif binary_file is not None:
        with open(binary_file, 'rb') as f:
            binary_content = f.read()

        base64_bytes = base64.b64encode(binary_content)
        base64_str = base64_bytes.decode('utf-8')

        return None, base64_str
예제 #18
0
def fetch_well_known_config(idp_url):
    """
    Fetch and parse the OAuth authorization-server metadata for ``idp_url``.

    Issues a GET request to ``{idp_url}/.well-known/oauth-authorization-server``
    and returns the body decoded with ``json.loads``. A network failure
    (RequestException), a non-200 status, or an undecodable body is reported
    through ``error_and_quit``.
    """
    known_config_url = "{}/.well-known/oauth-authorization-server".format(idp_url)
    # Advice shared by every failure message below.
    advice = ("Verify it is a valid workspace URL and that OAuth is "
              "enabled on this account.")
    try:
        response = requests.request(method="GET", url=known_config_url)
    except RequestException:
        error_and_quit("Unable to fetch OAuth configuration from {}.\n{}".format(
            idp_url, advice))

    status = response.status_code
    if status != 200:
        error_and_quit("Received status {} OAuth configuration from {}.\n {}".format(
            status, idp_url, advice))
    try:
        return json.loads(response.text)
    except json.decoder.JSONDecodeError:
        error_and_quit("Unable to decode OAuth configuration from {}.\n{}".format(
            idp_url, advice))
예제 #19
0
def _validate_pipeline_id(pipeline_id):
    """
    Quit with an error if the pipeline ID is missing (None) or empty.
    """
    if pipeline_id is not None and len(pipeline_id) > 0:
        return
    error_and_quit(u'Empty pipeline ID provided')