Example #1
 def assign_user_entitlements(self, current_user_ids, error_logger, user_log_file='users.log'):
     """
     assign user entitlements to allow cluster create, job create, sql analytics etc
     :param user_log_file:
     :param current_user_ids: dict of the userName to id mapping of the new env
     :return:
     """
     user_log = self.get_export_dir() + user_log_file
     if not os.path.exists(user_log):
         logging.info("Skipping user entitlement assignment. Logfile does not exist")
         return
     with open(user_log, 'r') as fp:
         # loop through each user in the file
         for line in fp:
             user = json.loads(line)
             userName = user['userName']
             if userName not in current_user_ids:
                 continue
             # add the users entitlements
             user_entitlements = user.get('entitlements', None)
             # get the current registered user id
             user_id = current_user_ids[user['userName']]
             if user_entitlements:
                 entitlements_args = self.assign_entitlements_args(user_entitlements)
                 update_resp = self.patch(f'/preview/scim/v2/Users/{user_id}', entitlements_args)
                 logging_utils.log_reponse_error(error_logger, update_resp)
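
The entitlement payload is built by assign_entitlements_args, which is not shown above. Below is a minimal sketch of the kind of SCIM 2.0 PatchOp body such a helper could produce; the exact shape in the tool may differ.

    # Hedged sketch: a SCIM 2.0 PatchOp body that adds entitlements to a user.
    # The entitlement values in the comment are illustrative only.
    def assign_entitlements_args_sketch(entitlements):
        # entitlements is the list exported in users.log, e.g.
        # [{"value": "allow-cluster-create"}, {"value": "databricks-sql-access"}]
        return {
            "schemas": ["urn:ietf:params:scim:api:messages:2.0:PatchOp"],
            "Operations": [{
                "op": "add",
                "path": "entitlements",
                "value": entitlements
            }]
        }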
Example #2
 def import_current_workspace_items(self, artifact_dir='artifacts/'):
     src_dir = self.get_export_dir() + artifact_dir
     error_logger = logging_utils.get_error_logger(
         wmconstants.WM_IMPORT, wmconstants.WORKSPACE_NOTEBOOK_OBJECT,
         self.get_export_dir())
     for root, subdirs, files in self.walk(src_dir):
         # replace the local directory with empty string to get the notebook workspace directory
         nb_dir = '/' + root.replace(src_dir, '')
         upload_dir = nb_dir
         if not nb_dir == '/':
             upload_dir = nb_dir + '/'
         if not self.does_path_exist(upload_dir):
             resp_mkdirs = self.post(WS_MKDIRS, {'path': upload_dir})
             if 'error_code' in resp_mkdirs:
                 logging_utils.log_reponse_error(error_logger, resp_mkdirs)
         for f in files:
             logging.info("Uploading: {0}".format(f))
             # create the local file path to load the DBC file
             local_file_path = os.path.join(root, f)
             # create the ws full file path including filename
             ws_file_path = upload_dir + f
             # generate json args with binary data for notebook to upload to the workspace path
             nb_input_args = self.get_user_import_args(
                 local_file_path, ws_file_path)
             # call import to the workspace
             if self.is_verbose():
                 logging.info("Path: {0}".format(nb_input_args['path']))
             resp_upload = self.post(WS_IMPORT, nb_input_args)
             if 'error_code' in resp_upload:
                 resp_upload['path'] = nb_input_args['path']
                 logging_utils.log_reponse_error(error_logger, resp_upload)
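
The workspace target path is derived by stripping the local export prefix from the walked directory. A small worked example with hypothetical paths shows the mapping:

    # Worked example of the path mapping above (paths are hypothetical).
    src_dir = '/tmp/export/artifacts/'
    root = '/tmp/export/artifacts/Users/[email protected]'
    f = 'etl_notebook.dbc'

    nb_dir = '/' + root.replace(src_dir, '')          # '/Users/[email protected]'
    upload_dir = nb_dir + '/' if nb_dir != '/' else nb_dir
    ws_file_path = upload_dir + f                     # '/Users/[email protected]/etl_notebook.dbc'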
Example #3
    def apply_acl_on_object(self, acl_str, error_logger, checkpoint_key_set):
        """
        apply the acl definition to the workspace object
        object_id comes from the export data which contains '/type/id' format for this key
        the object_id contains the {{/type/object_id}} format which helps craft the api endpoint
        setting acl definitions uses the patch rest api verb
        :param acl_str: the complete string from the logfile. contains object defn and acl lists
        """
        object_acl = json.loads(acl_str)
        # the object_type
        object_type = object_acl.get('object_type', None)
        obj_path = object_acl['path']
        logging.info(f"Working on ACL for path: {obj_path}")

        if not checkpoint_key_set.contains(obj_path):
            # We cannot modify '/Shared' directory's ACL
            if obj_path == "/Shared" and object_type == "directory":
                logging.info(
                    "We cannot modify /Shared directory's ACL. Skipping..")
                checkpoint_key_set.write(obj_path)
                return

            if self.is_user_ws_item(obj_path):
                ws_user = self.get_user(obj_path)
                if not self.does_user_exist(ws_user):
                    logging.info(
                        f"User workspace does not exist: {obj_path}, skipping ACL"
                    )
                    return
            obj_status = self.get(WS_STATUS, {'path': obj_path})
            if logging_utils.log_reponse_error(error_logger, obj_status):
                return
            logging.info("ws-stat: ", obj_status)
            current_obj_id = obj_status.get('object_id', None)
            if not current_obj_id:
                error_logger.error(
                    f'Object id missing from destination workspace: {obj_status}'
                )
                return
            if object_type == 'directory':
                object_id_with_type = f'/directories/{current_obj_id}'
            elif object_type == 'notebook':
                object_id_with_type = f'/notebooks/{current_obj_id}'
            else:
                error_logger.error(
                    f'Object for Workspace ACLs is Undefined: {obj_status}')
                return
            api_path = '/permissions' + object_id_with_type
            acl_list = object_acl.get('access_control_list', None)
            access_control_list = self.build_acl_args(acl_list)
            if access_control_list:
                api_args = {'access_control_list': access_control_list}
                resp = self.patch(api_path, api_args)
                if not logging_utils.log_reponse_error(error_logger, resp):
                    checkpoint_key_set.write(obj_path)
        return
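
build_acl_args is not shown above; it turns the exported ACL entries into the flat list the permissions PATCH endpoint accepts. A hedged sketch of that transformation, assuming the export carries all_permissions entries with an inherited flag (as returned by the permissions GET API):

    # Hedged sketch: flatten exported ACL entries into {principal, permission_level}
    # pairs, dropping inherited permissions. The tool's real build_acl_args may differ.
    def simplify_exported_acl_sketch(exported_acl):
        simplified = []
        for entry in exported_acl:
            # keep the principal key (user_name / group_name), drop the nested permissions
            principal = {k: v for k, v in entry.items() if k != 'all_permissions'}
            for perm in entry.get('all_permissions', []):
                if not perm.get('inherited', False):
                    simplified.append({**principal, 'permission_level': perm['permission_level']})
        return simplified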
Example #4
    def import_cluster_policies(self,
                                log_file='cluster_policies.log',
                                acl_log_file='acl_cluster_policies.log'):
        policies_log = self.get_export_dir() + log_file
        acl_policies_log = self.get_export_dir() + acl_log_file
        error_logger = logging_utils.get_error_logger(
            wmconstants.WM_IMPORT, wmconstants.CLUSTER_OBJECT,
            self.get_export_dir())
        checkpoint_cluster_policies_set = self._checkpoint_service.get_checkpoint_key_set(
            wmconstants.WM_IMPORT, wmconstants.CLUSTER_OBJECT)
        # create the policies
        if os.path.exists(policies_log):
            with open(policies_log, 'r') as policy_fp:
                for p in policy_fp:
                    policy_conf = json.loads(p)
                    if 'policy_id' in policy_conf and checkpoint_cluster_policies_set.contains(
                            policy_conf['policy_id']):
                        continue
                    # when creating the policy, we only need `name` and `definition` fields
                    create_args = {
                        'name': policy_conf['name'],
                        'definition': policy_conf['definition']
                    }
                    resp = self.post('/policies/clusters/create', create_args)
                    ignore_error_list = ['INVALID_PARAMETER_VALUE']
                    if not logging_utils.log_reponse_error(
                            error_logger,
                            resp,
                            ignore_error_list=ignore_error_list):
                        if 'policy_id' in policy_conf:
                            checkpoint_cluster_policies_set.write(
                                policy_conf['policy_id'])

            # ACLs are created by using the `access_control_list` key
            with open(acl_policies_log, 'r') as acl_fp:
                id_map = self.get_policy_id_by_name_dict()
                for x in acl_fp:
                    p_acl = json.loads(x)
                    if 'object_id' in p_acl and checkpoint_cluster_policies_set.contains(
                            p_acl['object_id']):
                        continue
                    acl_create_args = {
                        'access_control_list':
                        self.build_acl_args(p_acl['access_control_list'])
                    }
                    policy_id = id_map[p_acl['name']]
                    api = f'/permissions/cluster-policies/{policy_id}'
                    resp = self.put(api, acl_create_args)
                    if not logging_utils.log_reponse_error(error_logger, resp):
                        if 'object_id' in p_acl:
                            checkpoint_cluster_policies_set.write(
                                p_acl['object_id'])
        else:
            logging.info('Skipping cluster policies as no log file exists')
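
get_policy_id_by_name_dict is not shown above. A hedged sketch of such a helper, assuming it builds the name-to-id map from the cluster policies list endpoint:

    # Hedged sketch: map policy names to policy ids using the list endpoint.
    # The endpoint path and response fields are assumptions based on the create call above.
    def get_policy_id_by_name_dict_sketch(self):
        policies = self.get('/policies/clusters/list').get('policies', [])
        return {p['name']: p['policy_id'] for p in policies}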
Example #5
    def log_table_ddl(self, cid, ec_id, db_name, table_name, metastore_dir,
                      error_logger, has_unicode):
        """
        Log the table DDL to handle large DDL text
        :param cid: cluster id
        :param ec_id: execution context id (rest api 1.2)
        :param db_name: database name
        :param table_name: table name
        :param metastore_dir: metastore export directory name
        :param error_logger: logger to record failed API responses
        :param has_unicode: export to a file if this flag is true
        :return: True for success, False for error
        """
        set_ddl_str_cmd = f'ddl_str = spark.sql("show create table {db_name}.{table_name}").collect()[0][0]'
        ddl_str_resp = self.submit_command(cid, ec_id, set_ddl_str_cmd)

        if ddl_str_resp['resultType'] != 'text':
            ddl_str_resp['table'] = '{0}.{1}'.format(db_name, table_name)
            error_logger.error(json.dumps(ddl_str_resp))
            return False
        get_ddl_str_len = 'ddl_len = len(ddl_str); print(ddl_len)'
        len_resp = self.submit_command(cid, ec_id, get_ddl_str_len)
        ddl_len = int(len_resp['data'])
        if ddl_len <= 0:
            len_resp['table'] = '{0}.{1}'.format(db_name, table_name)
            error_logger.error(json.dumps(len_resp) + '\n')
            return False
        # if (len > 2k chars) OR (has unicode chars) then export to file
        table_ddl_path = self.get_export_dir() + metastore_dir + db_name + '/' + table_name
        if ddl_len > 2048 or has_unicode:
            # create the dbfs tmp path for exports / imports. no-op if exists
            resp = self.post('/dbfs/mkdirs', {'path': '/tmp/migration/'})
            if logging_utils.log_reponse_error(error_logger, resp):
                return False
            # save the ddl to the tmp path on dbfs
            save_ddl_cmd = "with open('/dbfs/tmp/migration/tmp_export_ddl.txt', 'w') as fp: fp.write(ddl_str)"
            save_resp = self.submit_command(cid, ec_id, save_ddl_cmd)
            if logging_utils.log_reponse_error(error_logger, save_resp):
                return False
            # read that data using the dbfs rest endpoint which can handle 2MB of text easily
            read_args = {'path': '/tmp/migration/tmp_export_ddl.txt'}
            read_resp = self.get('/dbfs/read', read_args)
            with open(table_ddl_path, "w") as fp:
                fp.write(
                    base64.b64decode(read_resp.get('data')).decode('utf-8'))
            return True
        else:
            export_ddl_cmd = 'print(ddl_str)'
            ddl_resp = self.submit_command(cid, ec_id, export_ddl_cmd)
            with open(table_ddl_path, "w") as fp:
                fp.write(ddl_resp.get('data'))
            return True
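
The /dbfs/read call above fetches the exported DDL in one request. If a DDL ever exceeded the per-request read limit, a chunked read along these lines could be used instead (a sketch assuming the standard offset/length parameters of the DBFS read endpoint):

    import base64

    # Hedged sketch: read a DBFS file in chunks, decoding each base64 block.
    # The chunk size is arbitrary; offset/length are standard DBFS read parameters.
    def read_dbfs_file_sketch(self, dbfs_path, chunk_size=1024 * 1024):
        offset, parts = 0, []
        while True:
            resp = self.get('/dbfs/read', {'path': dbfs_path, 'offset': offset, 'length': chunk_size})
            data = resp.get('data', '')
            if not data:
                break
            parts.append(base64.b64decode(data))
            bytes_read = resp.get('bytes_read', len(parts[-1]))
            offset += bytes_read
            if bytes_read < chunk_size:
                break
        return b''.join(parts).decode('utf-8')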
Example #6
    def import_groups(self, group_dir, current_user_ids, error_logger):
        checkpoint_groups_set = self._checkpoint_service.get_checkpoint_key_set(
            wmconstants.WM_IMPORT, wmconstants.GROUP_OBJECT)
        # list all the groups and create groups first
        if not os.path.exists(group_dir):
            logging.info("No groups to import.")
            return
        groups = self.listdir(group_dir)
        for x in groups:
            if not checkpoint_groups_set.contains(x):
                logging.info('Creating group: {0}'.format(x))
                # set the create args displayName property aka group name
                create_args = {
                    "schemas": ["urn:ietf:params:scim:schemas:core:2.0:Group"],
                    "displayName": x
                }
                group_resp = self.post('/preview/scim/v2/Groups', create_args)
                if not logging_utils.log_reponse_error(error_logger, group_resp):
                    checkpoint_groups_set.write(x)

        groups = self.listdir(group_dir)
        # dict of { group_name : group_id }
        current_group_ids = self.get_current_group_ids()
        # dict of { old_user_id : email }
        old_user_emails = self.get_old_user_emails()
        for group_name in groups:
            with open(group_dir + group_name, 'r') as fp:
                members = json.loads(fp.read()).get('members', None)
                logging.info(f"Importing group {group_name} :")
                if members:
                    # grab a list of ids to add either groups or users to this current group
                    member_id_list = []
                    for m in members:
                        if self.is_user(m):
                            try:
                                old_email = old_user_emails[m['value']]
                                this_user_id = current_user_ids.get(old_email, '')
                                if not this_user_id:
                                    error_logger.error(f'Unable to find user {old_email} in the new workspace. '
                                                       f"This user's email case has changed and needs to be updated with "
                                                       f'the --replace-old-email and --update-new-email options')
                                member_id_list.append(this_user_id)
                            except KeyError:
                                error_logger.error(f"Error adding member {m} to group {group_name}")
                        elif self.is_group(m):
                            this_group_id = current_group_ids.get(m['display'])
                            member_id_list.append(this_group_id)
                        else:
                            logging.info("Skipping service principal members and other identities not within users/groups")
                    add_members_json = self.get_member_args(member_id_list)
                    group_id = current_group_ids[group_name]
                    add_resp = self.patch('/preview/scim/v2/Groups/{0}'.format(group_id), add_members_json)
                    logging_utils.log_reponse_error(error_logger, add_resp)
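
get_member_args is not shown above. A minimal sketch of the kind of SCIM 2.0 PatchOp body such a helper could build from the collected member ids; the exact shape in the tool may differ.

    # Hedged sketch: a SCIM 2.0 PatchOp body that adds members (users or groups) by id.
    def get_member_args_sketch(member_id_list):
        return {
            "schemas": ["urn:ietf:params:scim:api:messages:2.0:PatchOp"],
            "Operations": [{
                "op": "add",
                "value": {"members": [{"value": str(m_id)} for m_id in member_id_list]}
            }]
        }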
Example #7
 def get_secret_value(self, scope_name, secret_key, cid, ec_id,
                      error_logger):
     cmd_set_value = f"value = dbutils.secrets.get(scope = '{scope_name}', key = '{secret_key}')"
     cmd_convert_b64 = "import base64; b64_value = base64.b64encode(value.encode('ascii'))"
     cmd_get_b64 = "print(b64_value.decode('ascii'))"
     results_set = self.submit_command(cid, ec_id, cmd_set_value)
     results_convert = self.submit_command(cid, ec_id, cmd_convert_b64)
     results_get = self.submit_command(cid, ec_id, cmd_get_b64)
     if logging_utils.log_reponse_error(error_logger, results_set) \
             or logging_utils.log_reponse_error(error_logger, results_convert) \
             or logging_utils.log_reponse_error(error_logger, results_get):
         return None
     else:
         return results_get.get('data')
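
The command output is the secret encoded as base64 ASCII, so the value survives the text-based command API unchanged. A tiny round-trip example shows how the importer later recovers the plain value before calling /secrets/put:

    import base64

    # Round-trip of the encode used on the cluster and the decode used at import time.
    plain = 'my-secret-value'                                   # hypothetical secret
    b64_value = base64.b64encode(plain.encode('ascii')).decode('ascii')
    restored = base64.b64decode(b64_value.encode('ascii')).decode('ascii')
    assert restored == plain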
Example #8
 def log_all_secrets(self, cluster_name=None, log_dir='secret_scopes/'):
     scopes_dir = self.get_export_dir() + log_dir
     scopes_list = self.get_secret_scopes_list()
     error_logger = logging_utils.get_error_logger(
         wmconstants.WM_EXPORT, wmconstants.SECRET_OBJECT,
         self.get_export_dir())
     os.makedirs(scopes_dir, exist_ok=True)
     start = timer()
     cid = self.start_cluster_by_name(
         cluster_name) if cluster_name else self.launch_cluster()
     time.sleep(5)
     ec_id = self.get_execution_context(cid)
     for scope_json in scopes_list:
         scope_name = scope_json.get('name')
         secrets_list = self.get_secrets(scope_name)
         if logging_utils.log_reponse_error(error_logger, secrets_list):
             continue
         scopes_logfile = scopes_dir + scope_name
         try:
             with open(scopes_logfile, 'w') as fp:
                 for secret_json in secrets_list:
                     secret_name = secret_json.get('key')
                     b64_value = self.get_secret_value(
                         scope_name, secret_name, cid, ec_id, error_logger)
                     s_json = {'name': secret_name, 'value': b64_value}
                     fp.write(json.dumps(s_json) + '\n')
         except ValueError as error:
             if "embedded null byte" in str(error):
                 error_msg = f"{scopes_logfile} has bad name and hence cannot open: {str(error)} Skipping.."
                 logging.error(error_msg)
                 error_logger.error(error_msg)
             else:
                 raise error
Example #9
 def import_instance_profiles(self, log_file='instance_profiles.log'):
     # currently an AWS only operation
     error_logger = logging_utils.get_error_logger(
         wmconstants.WM_IMPORT, wmconstants.INSTANCE_PROFILE_OBJECT,
         self.get_export_dir())
     ip_log = self.get_export_dir() + log_file
     if not os.path.exists(ip_log):
         logging.info("No instance profiles to import.")
         return
     # check current profiles and skip if the profile already exists
     ip_list = self.get('/instance-profiles/list').get('instance_profiles', None)
     if ip_list:
         list_of_profiles = [x['instance_profile_arn'] for x in ip_list]
     else:
         list_of_profiles = []
     import_profiles_count = 0
     with open(ip_log, "r") as fp:
         for line in fp:
             ip_arn = json.loads(line).get('instance_profile_arn', None)
             if ip_arn not in list_of_profiles:
                 print("Importing arn: {0}".format(ip_arn))
                 resp = self.post('/instance-profiles/add',
                                  {'instance_profile_arn': ip_arn})
                 if not logging_utils.log_reponse_error(error_logger, resp):
                     import_profiles_count += 1
             else:
                 logging.info(
                     "Skipping since profile already exists: {0}".format(
                         ip_arn))
     return import_profiles_count
Example #10
 def assign_group_entitlements(self, group_dir, error_logger):
     # assign group role ACLs, which are only available via SCIM apis
     group_ids = self.get_current_group_ids()
     if not os.path.exists(group_dir):
         logging.info("No groups defined. Skipping group entitlement assignment")
         return
     groups = self.listdir(group_dir)
     for group_name in groups:
         with open(group_dir + group_name, 'r') as fp:
             group_data = json.loads(fp.read())
             entitlements = group_data.get('entitlements', None)
             if entitlements:
                 g_id = group_ids[group_name]
                 update_entitlements = self.assign_entitlements_args(entitlements)
                 up_resp = self.patch(f'/preview/scim/v2/Groups/{g_id}', update_entitlements)
                 logging_utils.log_reponse_error(error_logger, up_resp)
Example #11
 def assign_user_roles(self, current_user_ids, error_logger, user_log_file='users.log'):
     """
     assign user roles that are missing after adding group assignment
     Note: There is a limitation in the exposed API. If a user is assigned a role directly and the same role
     is also granted via a group, we can't distinguish between the two. Only the group assignment will be migrated.
     :param user_log_file: logfile of all user properties
     :param current_user_ids: dict of the userName to id mapping of the new env
     :return:
     """
     user_log = self.get_export_dir() + user_log_file
     if not os.path.exists(user_log):
         logging.info("Skipping user entitlement assignment. Logfile does not exist")
         return
     # keys to filter from the user log to get the user / role mapping
     old_role_keys = ('userName', 'roles')
     cur_role_keys = ('schemas', 'userName', 'entitlements', 'roles', 'groups')
     # get current user id of the new environment, k,v = email, id
     with open(user_log, 'r') as fp:
         # loop through each user in the file
         for line in fp:
             user = json.loads(line)
             user_roles = {k: user[k] for k in old_role_keys if k in user}
             userName = user['userName']
             if userName not in current_user_ids:
                 continue
             # get the current registered user id
             user_id = current_user_ids[user['userName']]
             # get the current users settings
             cur_user = self.get('/preview/scim/v2/Users/{0}'.format(user_id))
             # get the current users IAM roles
             current_roles = cur_user.get('roles', None)
             if current_roles:
                 cur_role_values = set([x['value'] for x in current_roles])
             else:
                 cur_role_values = set()
             # get the users saved IAM roles from the export
             saved_roles = user_roles.get('roles', None)
             if saved_roles:
                 saved_role_values = set([y['value'] for y in saved_roles])
             else:
                 saved_role_values = set()
             roles_needed = list(saved_role_values - cur_role_values)
             if roles_needed:
                 # get the json to add the roles to the user profile
                 patch_roles = self.add_roles_arg(roles_needed)
                 update_resp = self.patch(f'/preview/scim/v2/Users/{user_id}', patch_roles)
                 logging_utils.log_reponse_error(error_logger, update_resp)
Example #12
 def import_instance_pools(self, log_file='instance_pools.log'):
     pool_log = self.get_export_dir() + log_file
     error_logger = logging_utils.get_error_logger(
         wmconstants.WM_IMPORT, wmconstants.INSTANCE_POOL_OBJECT,
         self.get_export_dir())
     if not os.path.exists(pool_log):
         logging.info("No instance pools to import.")
         return
     with open(pool_log, 'r') as fp:
         for line in fp:
             pool_conf = json.loads(line)
             pool_resp = self.post('/instance-pools/create', pool_conf)
             ignore_error_list = ['INVALID_PARAMETER_VALUE']
             logging_utils.log_reponse_error(
                 error_logger,
                 pool_resp,
                 ignore_error_list=ignore_error_list)
Example #13
 def _import_users_helper(self, user_data, create_keys, checkpoint_set, error_logger):
     user = json.loads(user_data)
     user_name = user['userName']
     if not checkpoint_set.contains(user_name):
         logging.info("Creating user: {0}".format(user_name))
         user_create = {k: user[k] for k in create_keys if k in user}
         create_resp = self.post('/preview/scim/v2/Users', user_create)
         if not logging_utils.log_reponse_error(error_logger, create_resp):
             checkpoint_set.write(user_name)
Example #14
 def _acl_log_helper(json_data):
     data = json.loads(json_data)
     obj_id = data.get('object_id', None)
     api_endpoint = '/permissions/{0}/{1}'.format(artifact_type, obj_id)
     acl_resp = self.get(api_endpoint)
     acl_resp['path'] = data.get('path')
     if logging_utils.log_reponse_error(error_logger, acl_resp):
         return
     acl_resp.pop('http_status_code')
     writer.write(json.dumps(acl_resp) + '\n')
Example #15
 def download_notebook_helper(self,
                              notebook_data,
                              checkpoint_notebook_set,
                              error_logger,
                              export_dir='artifacts/'):
     """
     Helper function to download an individual notebook, or log the failure in the failure logfile
     :param notebook_path: an individual notebook path
     :param export_dir: directory to store all notebooks
     :return: return the notebook path that's successfully downloaded
     """
     notebook_path = json.loads(notebook_data).get('path', None).rstrip('\n')
     if checkpoint_notebook_set.contains(notebook_path):
         return {'path': notebook_path}
     get_args = {'path': notebook_path, 'format': self.get_file_format()}
     if self.is_verbose():
         logging.info("Downloading: {0}".format(get_args['path']))
     resp = self.get(WS_EXPORT, get_args)
     if resp.get('error', None):
         resp['path'] = notebook_path
         logging_utils.log_reponse_error(error_logger, resp)
         return resp
     if resp.get('error_code', None):
         resp['path'] = notebook_path
         logging_utils.log_reponse_error(error_logger, resp)
         return resp
     nb_path = os.path.dirname(notebook_path)
     if nb_path != '/':
         # path is NOT empty, remove the trailing slash from export_dir
         save_path = export_dir[:-1] + nb_path + '/'
     else:
         save_path = export_dir
     save_filename = save_path + os.path.basename(notebook_path) + '.' + resp.get('file_type')
     # If the local path doesn't exist, we create it before we save the contents
     if not os.path.exists(save_path) and save_path:
         os.makedirs(save_path, exist_ok=True)
     with open(save_filename, "wb") as f:
         f.write(base64.b64decode(resp['content']))
     checkpoint_notebook_set.write(notebook_path)
     return {'path': notebook_path}
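
A small worked example of the save-path derivation above, with hypothetical values:

    import os

    # Worked example of how the local save path mirrors the workspace path.
    export_dir = 'artifacts/'
    notebook_path = '/Users/[email protected]/etl/daily_report'   # hypothetical notebook
    file_type = 'dbc'

    nb_path = os.path.dirname(notebook_path)                     # '/Users/[email protected]/etl'
    save_path = export_dir[:-1] + nb_path + '/' if nb_path != '/' else export_dir
    save_filename = save_path + os.path.basename(notebook_path) + '.' + file_type
    # save_filename == 'artifacts/Users/[email protected]/etl/daily_report.dbc'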
Example #16
 def _file_upload_helper(f):
     logging.info("Uploading: {0}".format(f))
     # create the local file path to load the DBC file
     local_file_path = os.path.join(root, f)
     # create the ws full file path including filename
     ws_file_path = upload_dir + f
     if checkpoint_notebook_set.contains(ws_file_path):
         return
     # generate json args with binary data for notebook to upload to the workspace path
     nb_input_args = self.get_user_import_args(
         local_file_path, ws_file_path)
     # call import to the workspace
     if self.is_verbose():
         logging.info("Path: {0}".format(nb_input_args['path']))
     resp_upload = self.post(WS_IMPORT, nb_input_args)
     if 'error_code' in resp_upload:
         resp_upload['path'] = ws_file_path
         logging.info(f'Error uploading file: {ws_file_path}')
         logging_utils.log_reponse_error(error_logger, resp_upload)
     else:
         checkpoint_notebook_set.write(ws_file_path)
Example #17
 def log_all_secrets_acls(self, log_name='secret_scopes_acls.log'):
     acls_file = self.get_export_dir() + log_name
     error_logger = logging_utils.get_error_logger(
         wmconstants.WM_EXPORT, wmconstants.SECRET_OBJECT,
         self.get_export_dir())
     scopes_list = self.get_secret_scopes_list()
     with open(acls_file, 'w') as fp:
         for scope_json in scopes_list:
             scope_name = scope_json.get('name', None)
             resp = self.get('/secrets/acls/list', {'scope': scope_name})
             if logging_utils.log_reponse_error(error_logger, resp):
                 return
             else:
                 resp['scope_name'] = scope_name
                 fp.write(json.dumps(resp) + '\n')
Example #18
 def _put_mlflow_experiment_acl(self, acl_str, experiment_id_map,
                                checkpoint_key_set, error_logger):
     acl_obj = json.loads(acl_str)
     experiment_id = acl_obj["object_id"].split("/")[-1]
     if checkpoint_key_set.contains(experiment_id):
         return
     if experiment_id not in experiment_id_map:
         error_msg = f"experiment_id: {experiment_id} does not exist in mlflow_experiments_id_map.log. Skipping... but logged to error log file."
         error_logger.error(error_msg)
         return
     new_experiment_id = experiment_id_map[experiment_id]
     acl_create_args = {
         'access_control_list':
         self.build_acl_args(acl_obj['access_control_list'], True)
     }
     resp = self.put(f'/permissions/experiments/{new_experiment_id}',
                     acl_create_args)
     if not logging_utils.log_reponse_error(error_logger, resp):
         checkpoint_key_set.write(experiment_id)
Example #19
 def log_job_configs(self,
                     users_list=[],
                     log_file='jobs.log',
                     acl_file='acl_jobs.log'):
     """
     log all job configs and the ACLs for each job
     :param users_list: a list of users / emails to filter the results upon (optional for group exports)
     :param log_file: log file to store job configs as json entries per line
     :param acl_file: log file to store job ACLs
     :return:
     """
     jobs_log = self.get_export_dir() + log_file
     acl_jobs_log = self.get_export_dir() + acl_file
     error_logger = logging_utils.get_error_logger(wmconstants.WM_EXPORT,
                                                   wmconstants.JOB_OBJECT,
                                                   self.get_export_dir())
     # pinned by cluster_user is a flag per cluster
     jl_full = self.get_jobs_list(False)
     if users_list:
         # filter the jobs list to only contain users that exist within this list
         jl = list(
             filter(lambda x: x.get('creator_user_name', '') in users_list,
                    jl_full))
     else:
         jl = jl_full
     with open(jobs_log, "w") as log_fp, open(acl_jobs_log, 'w') as acl_fp:
         for x in jl:
             job_id = x['job_id']
             new_job_name = x['settings']['name'] + ':::' + str(job_id)
             # grab the settings obj
             job_settings = x['settings']
             # update the job name
             job_settings['name'] = new_job_name
             # reset the original struct with the new settings
             x['settings'] = job_settings
             log_fp.write(json.dumps(x) + '\n')
             job_perms = self.get(f'/preview/permissions/jobs/{job_id}')
             if not logging_utils.log_reponse_error(error_logger,
                                                    job_perms):
                 job_perms['job_name'] = new_job_name
                 acl_fp.write(json.dumps(job_perms) + '\n')
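
The ':::' delimiter plus job id keeps duplicate job names unique across export and import; update_imported_job_names later strips the suffix again. A tiny worked example with a hypothetical job id:

    # Worked example of the job-name suffix used to disambiguate duplicate job names.
    job_id = 123                                            # hypothetical id
    original_name = 'nightly_etl'
    exported_name = original_name + ':::' + str(job_id)     # 'nightly_etl:::123'
    restored_name = exported_name.split(':::')[0]           # 'nightly_etl'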
Example #20
 def update_imported_job_names(self, error_logger,
                               checkpoint_job_configs_set):
     # loop through and update the job names to remove the custom delimiter + job_id suffix
     current_jobs_list = self.get_jobs_list()
     for job in current_jobs_list:
         job_id = job['job_id']
         job_name = job['settings']['name']
         # job name was set to `old_job_name:::{job_id}` to support duplicate job names
         # we need to parse the old job name and update the current jobs
         if checkpoint_job_configs_set.contains(job_name):
             continue
         old_job_name = job_name.split(':::')[0]
         new_settings = {'name': old_job_name}
         update_args = {'job_id': job_id, 'new_settings': new_settings}
         logging.info(f'Updating job name: {update_args}')
         resp = self.post('/jobs/update', update_args)
         if not logging_utils.log_reponse_error(error_logger, resp):
             checkpoint_job_configs_set.write(job_name)
         else:
             raise RuntimeError(
                 "Import job has failed. Refer to the previous log messages to investigate."
             )
Example #21
 def _get_mlflow_experiment_acls(self, acl_log_file_writer, experiment_str,
                                 checkpoint_key_set, error_logger):
     experiment_obj = json.loads(experiment_str)
     experiment_id = experiment_obj.get('experiment_id')
     experiment_type = experiment_obj.get('tags', {}).get('mlflow.experimentType')
     if checkpoint_key_set.contains(experiment_id):
         return
     if experiment_type != "MLFLOW_EXPERIMENT":
         logging.info(
             f"Experiment {experiment_id}'s experimentType is {experiment_type}. Only "
             "MLFLOW_EXPERIMENT type's permissions are exported. Skipping..."
         )
         return
     logging.info(f"Exporting ACLs for experiment_id: {experiment_id}.")
     perms = self.get(f"/permissions/experiments/{experiment_id}",
                      do_not_throw=True)
     if not logging_utils.log_reponse_error(error_logger, perms):
         acl_log_file_writer.write(json.dumps(perms) + '\n')
         checkpoint_key_set.write(experiment_id)
         logging.info(
             f"Successfully exported ACLs for experiment_id: {experiment_id}."
         )
Example #22
    def get_all_databases(self, error_logger, cid, ec_id):
        # submit first command to find number of databases
        # DBR 7.0 changes databaseName to namespace for the return value of show databases
        all_dbs_cmd = 'all_dbs = [x.databaseName for x in spark.sql("show databases").collect()]; print(len(all_dbs))'
        results = self.submit_command(cid, ec_id, all_dbs_cmd)
        if logging_utils.log_reponse_error(error_logger, results):
            raise ValueError(
                "Cannot identify number of databases due to the above error")
        num_of_dbs = ast.literal_eval(results['data'])
        batch_size = 100  # batch size to iterate over databases
        num_of_buckets = (num_of_dbs // batch_size) + 1  # number of slices of the list to take

        all_dbs = []
        for m in range(0, num_of_buckets):
            db_slice = 'print(all_dbs[{0}:{1}])'.format(
                batch_size * m, batch_size * (m + 1))
            results = self.submit_command(cid, ec_id, db_slice)
            db_names = ast.literal_eval(results['data'])
            for db in db_names:
                all_dbs.append(db)
                logging.info("Database: {0}".format(db))
        return all_dbs
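
A worked example of the batching arithmetic above: the full database list stays on the cluster, and only 100-name slices are printed back and parsed locally.

    # With 250 databases and a batch size of 100, three slices are fetched.
    num_of_dbs, batch_size = 250, 100
    num_of_buckets = (num_of_dbs // batch_size) + 1                       # 3
    slices = [(batch_size * m, batch_size * (m + 1)) for m in range(num_of_buckets)]
    # [(0, 100), (100, 200), (200, 300)]; an exact multiple of the batch size just
    # yields one empty trailing slice, which is harmless.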
Example #23
    def import_cluster_configs(self,
                               log_file='clusters.log',
                               acl_log_file='acl_clusters.log',
                               filter_user=None):
        """
        Import cluster configs and update appropriate properties / tags in the new env
        :param log_file:
        :return:
        """
        cluster_log = self.get_export_dir() + log_file
        acl_cluster_log = self.get_export_dir() + acl_log_file
        if not os.path.exists(cluster_log):
            logging.info("No clusters to import.")
            return
        current_cluster_names = {x.get('cluster_name', None) for x in self.get_cluster_list(False)}
        old_2_new_policy_ids = self.get_new_policy_id_dict()  # dict of {old_id : new_id}
        error_logger = logging_utils.get_error_logger(
            wmconstants.WM_IMPORT, wmconstants.CLUSTER_OBJECT,
            self.get_export_dir())
        checkpoint_cluster_configs_set = self._checkpoint_service.get_checkpoint_key_set(
            wmconstants.WM_IMPORT, wmconstants.CLUSTER_OBJECT)
        # get instance pool id mappings
        with open(cluster_log, 'r') as fp:
            for line in fp:
                cluster_conf = json.loads(line)
                if 'cluster_id' in cluster_conf and checkpoint_cluster_configs_set.contains(
                        cluster_conf['cluster_id']):
                    continue
                cluster_name = cluster_conf['cluster_name']
                if cluster_name in current_cluster_names:
                    logging.info(
                        "Cluster already exists, skipping: {0}".format(
                            cluster_name))
                    continue
                cluster_creator = cluster_conf.pop('creator_user_name')
                if 'policy_id' in cluster_conf:
                    old_policy_id = cluster_conf['policy_id']
                    cluster_conf['policy_id'] = old_2_new_policy_ids[old_policy_id]
                # check for instance pools and modify cluster attributes
                if 'instance_pool_id' in cluster_conf:
                    new_cluster_conf = self.cleanup_cluster_pool_configs(
                        cluster_conf, cluster_creator)
                else:
                    # update cluster configs for non-pool clusters
                    # add original creator tag to help with DBU tracking
                    if 'custom_tags' in cluster_conf:
                        tags = cluster_conf['custom_tags']
                        tags['OriginalCreator'] = cluster_creator
                        cluster_conf['custom_tags'] = tags
                    else:
                        cluster_conf['custom_tags'] = {
                            'OriginalCreator': cluster_creator
                        }
                    new_cluster_conf = cluster_conf
                print("Creating cluster: {0}".format(
                    new_cluster_conf['cluster_name']))
                cluster_resp = self.post('/clusters/create', new_cluster_conf)
                if cluster_resp['http_status_code'] == 200:
                    stop_resp = self.post(
                        '/clusters/delete',
                        {'cluster_id': cluster_resp['cluster_id']})
                    if 'pinned_by_user_name' in cluster_conf:
                        pin_resp = self.post(
                            '/clusters/pin',
                            {'cluster_id': cluster_resp['cluster_id']})
                    if 'cluster_id' in cluster_conf:
                        checkpoint_cluster_configs_set.write(
                            cluster_conf['cluster_id'])
                else:
                    logging_utils.log_reponse_error(error_logger, cluster_resp)
                    print(cluster_resp)

        # TODO: Maybe move this into a separate step to make it more rerunnable.
        self._log_cluster_ids_and_original_creators(log_file)

        # add cluster ACLs
        # loop through and reapply cluster ACLs
        with open(acl_cluster_log, 'r') as acl_fp:
            for x in acl_fp:
                data = json.loads(x)
                if 'object_id' in data and checkpoint_cluster_configs_set.contains(
                        data['object_id']):
                    continue
                cluster_name = data['cluster_name']
                print(f'Applying acl for {cluster_name}')
                acl_args = {
                    'access_control_list':
                    self.build_acl_args(data['access_control_list'])
                }
                cid = self.get_cluster_id_by_name(cluster_name)
                if cid is None:
                    error_message = f'Cluster id must exist in new env for cluster_name: {cluster_name}. ' \
                                    f'Re-import cluster configs.'
                    raise ValueError(error_message)
                api = f'/preview/permissions/clusters/{cid}'
                resp = self.put(api, acl_args)
                if not logging_utils.log_reponse_error(error_logger, resp):
                    if 'object_id' in data:
                        checkpoint_cluster_configs_set.write(data['object_id'])
                print(resp)
Example #24
        def _upload_all_files(root, subdirs, files):
            '''
            Upload all files in parallel in root (current) directory.
            '''
            # replace the local directory with empty string to get the notebook workspace directory
            nb_dir = '/' + root.replace(src_dir, '')
            upload_dir = nb_dir
            if not nb_dir == '/':
                upload_dir = nb_dir + '/'
            if self.is_user_ws_item(upload_dir):
                ws_user = self.get_user(upload_dir)
                if archive_missing:
                    if ws_user in archive_users:
                        upload_dir = upload_dir.replace('Users', 'Archive', 1)
                    elif not self.does_user_exist(ws_user):
                        # add the user to the cache / set of missing users
                        logging.info(
                            "User workspace does not exist, adding to archive cache: {0}"
                            .format(ws_user))
                        archive_users.add(ws_user)
                        # append the archive path to the upload directory
                        upload_dir = upload_dir.replace('Users', 'Archive', 1)
                    else:
                        logging.info(
                            "User workspace exists: {0}".format(ws_user))
                elif not self.does_user_exist(ws_user):
                    logging.info(
                        "User {0} is missing. "
                        "Please re-run with --archive-missing flag "
                        "or first verify all users exist in the new workspace".
                        format(ws_user))
                    return
                else:
                    logging.info("Uploading for user: {0}".format(ws_user))
            # make the top level folder before uploading files within the loop
            if not self.is_user_ws_root(upload_dir):
                # skip folder creation for the /Users/[email protected]/ root path since it already exists
                resp_mkdirs = self.post(WS_MKDIRS, {'path': upload_dir})
                if 'error_code' in resp_mkdirs:
                    resp_mkdirs['path'] = upload_dir
                    logging_utils.log_reponse_error(error_logger, resp_mkdirs)

            def _file_upload_helper(f):
                logging.info("Uploading: {0}".format(f))
                # create the local file path to load the DBC file
                local_file_path = os.path.join(root, f)
                # create the ws full file path including filename
                ws_file_path = upload_dir + f
                if checkpoint_notebook_set.contains(ws_file_path):
                    return
                # generate json args with binary data for notebook to upload to the workspace path
                nb_input_args = self.get_user_import_args(
                    local_file_path, ws_file_path)
                # call import to the workspace
                if self.is_verbose():
                    logging.info("Path: {0}".format(nb_input_args['path']))
                resp_upload = self.post(WS_IMPORT, nb_input_args)
                if 'error_code' in resp_upload:
                    resp_upload['path'] = ws_file_path
                    logging.info(f'Error uploading file: {ws_file_path}')
                    logging_utils.log_reponse_error(error_logger, resp_upload)
                else:
                    checkpoint_notebook_set.write(ws_file_path)

            with ThreadPoolExecutor(max_workers=num_parallel) as executor:
                futures = [
                    executor.submit(_file_upload_helper, file)
                    for file in files
                ]
                concurrent.futures.wait(futures, return_when="FIRST_EXCEPTION")
                propagate_exceptions(futures)
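
propagate_exceptions is not shown above. A hedged sketch of the usual pattern such a helper follows, re-raising the first exception captured by a completed future so the caller fails fast:

    # Hedged sketch; the tool's real propagate_exceptions may differ.
    def propagate_exceptions_sketch(futures):
        for future in futures:
            if future.done() and future.exception() is not None:
                raise future.exception()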
Example #25
    def import_all_secrets(self, log_dir='secret_scopes/'):
        scopes_dir = self.get_export_dir() + log_dir
        error_logger = logging_utils.get_error_logger(
            wmconstants.WM_IMPORT, wmconstants.SECRET_OBJECT,
            self.get_export_dir())
        scopes_acl_dict = self.load_acl_dict()
        for root, subdirs, files in self.walk(scopes_dir):
            for scope_name in files:
                file_path = root + scope_name
                # print('Log file: ', file_path)
                # check if scopes acls are empty, then skip
                if scopes_acl_dict.get(scope_name, None) is None:
                    print(
                        "Scope is empty with no manage permissions. Skipping..."
                    )
                    continue
                # check if users has can manage perms then we can add during creation time
                has_user_manage = self.has_users_can_manage_permission(
                    scope_name, scopes_acl_dict)
                create_scope_args = {'scope': scope_name}
                if has_user_manage:
                    create_scope_args['initial_manage_principal'] = 'users'
                other_permissions = self.get_all_other_permissions(
                    scope_name, scopes_acl_dict)
                create_resp = self.post('/secrets/scopes/create',
                                        create_scope_args)
                logging_utils.log_reponse_error(
                    error_logger,
                    create_resp,
                    ignore_error_list=['RESOURCE_ALREADY_EXISTS'])
                if other_permissions:
                    # use this dict minus the `users:MANAGE` permissions and apply the other permissions to the scope
                    for perm, principal_list in other_permissions.items():
                        put_acl_args = {
                            "scope": scope_name,
                            "permission": perm
                        }
                        for x in principal_list:
                            put_acl_args["principal"] = x
                            logging.info(put_acl_args)
                            put_resp = self.post('/secrets/acls/put',
                                                 put_acl_args)
                            logging_utils.log_reponse_error(
                                error_logger, put_resp)
                # loop through the scope and create the k/v pairs
                with open(file_path, 'r') as fp:
                    for s in fp:
                        s_dict = json.loads(s)
                        k = s_dict.get('name')
                        v = s_dict.get('value')
                        if 'WARNING: skipped' in v:
                            error_logger.error(
                                f"Skipping scope {scope_name} as value is corrupted due to being too large \n"
                            )
                            continue
                        try:
                            put_secret_args = {
                                'scope': scope_name,
                                'key': k,
                                'string_value': base64.b64decode(v.encode('ascii')).decode('ascii')
                            }
                            put_resp = self.post('/secrets/put',
                                                 put_secret_args)
                            logging_utils.log_reponse_error(
                                error_logger, put_resp)
                        except Exception as error:
                            if "Invalid base64-encoded string" in str(
                                    error) or 'decode' in str(
                                        error) or "padding" in str(error):
                                error_msg = f"secret_scope: {scope_name} has invalid invalid data characters: {str(error)} skipping.. and logging to error file."
                                logging.error(error_msg)
                                error_logger.error(error_msg)

                            else:
                                raise error
Example #26
    def import_job_configs(self, log_file='jobs.log', acl_file='acl_jobs.log'):
        jobs_log = self.get_export_dir() + log_file
        acl_jobs_log = self.get_export_dir() + acl_file
        error_logger = logging_utils.get_error_logger(wmconstants.WM_IMPORT,
                                                      wmconstants.JOB_OBJECT,
                                                      self.get_export_dir())
        if not os.path.exists(jobs_log):
            logging.info("No job configurations to import.")
            return
        # get an old cluster id to new cluster id mapping object
        cluster_mapping = self.get_cluster_id_mapping()
        old_2_new_policy_ids = self.get_new_policy_id_dict()  # dict { old_policy_id : new_policy_id }
        checkpoint_job_configs_set = self._checkpoint_service.get_checkpoint_key_set(
            wmconstants.WM_IMPORT, wmconstants.JOB_OBJECT)

        def adjust_ids_for_cluster(settings):  #job_settings or task_settings
            if 'existing_cluster_id' in settings:
                old_cid = settings['existing_cluster_id']
                # set new cluster id for existing cluster attribute
                new_cid = cluster_mapping.get(old_cid, None)
                if not new_cid:
                    logging.info(
                        "Existing cluster has been removed. Resetting job to use new cluster."
                    )
                    settings.pop('existing_cluster_id')
                    settings['new_cluster'] = self.get_jobs_default_cluster_conf()
                else:
                    settings['existing_cluster_id'] = new_cid
            else:  # new cluster config
                cluster_conf = settings['new_cluster']
                if 'policy_id' in cluster_conf:
                    old_policy_id = cluster_conf['policy_id']
                    cluster_conf['policy_id'] = old_2_new_policy_ids[old_policy_id]
                # check for instance pools and modify cluster attributes
                if 'instance_pool_id' in cluster_conf:
                    new_cluster_conf = self.cleanup_cluster_pool_configs(
                        cluster_conf, job_creator, True)
                else:
                    new_cluster_conf = cluster_conf
                settings['new_cluster'] = new_cluster_conf

        with open(jobs_log, 'r') as fp:
            for line in fp:
                job_conf = json.loads(line)
                # job_id is parsed as an integer, while checkpoint keys are strings; without str(...)
                # the comparison is str vs int and the checkpoint never recognizes an already-imported job_id
                if 'job_id' in job_conf and checkpoint_job_configs_set.contains(
                        str(job_conf['job_id'])):
                    continue
                job_creator = job_conf.get('creator_user_name', '')
                job_settings = job_conf['settings']
                job_schedule = job_settings.get('schedule', None)
                if job_schedule:
                    # set all imported jobs as paused
                    job_schedule['pause_status'] = 'PAUSED'
                    job_settings['schedule'] = job_schedule
                if 'format' not in job_settings or job_settings.get(
                        'format') == 'SINGLE_TASK':
                    adjust_ids_for_cluster(job_settings)
                else:
                    for task_settings in job_settings.get('tasks', []):
                        adjust_ids_for_cluster(task_settings)

                logging.info("Current Job Name: {0}".format(
                    job_conf['settings']['name']))
                # creator can be none if the user is no longer in the org. see our docs page
                create_resp = self.post('/jobs/create', job_settings)
                if logging_utils.check_error(create_resp):
                    logging.info(
                        "Resetting job to use default cluster configs due to expired configurations."
                    )
                    job_settings['new_cluster'] = self.get_jobs_default_cluster_conf()
                    create_resp_retry = self.post('/jobs/create', job_settings)
                    if not logging_utils.log_reponse_error(
                            error_logger, create_resp_retry):
                        if 'job_id' in job_conf:
                            checkpoint_job_configs_set.write(
                                job_conf["job_id"])
                    else:
                        raise RuntimeError(
                            "Import job has failed. Refer to the previous log messages to investigate."
                        )

                else:
                    if 'job_id' in job_conf:
                        checkpoint_job_configs_set.write(job_conf["job_id"])

        # update the jobs with their ACLs
        with open(acl_jobs_log, 'r') as acl_fp:
            job_id_by_name = self.get_job_id_by_name()
            for line in acl_fp:
                acl_conf = json.loads(line)
                if 'object_id' in acl_conf and checkpoint_job_configs_set.contains(
                        acl_conf['object_id']):
                    continue
                current_job_id = job_id_by_name[acl_conf['job_name']]
                job_path = f'jobs/{current_job_id}'  # contains `/jobs/{job_id}` path
                api = f'/preview/permissions/{job_path}'
                # get acl permissions for jobs
                acl_perms = self.build_acl_args(
                    acl_conf['access_control_list'], True)
                acl_create_args = {'access_control_list': acl_perms}
                acl_resp = self.patch(api, acl_create_args)
                if not logging_utils.log_reponse_error(
                        error_logger, acl_resp) and 'object_id' in acl_conf:
                    checkpoint_job_configs_set.write(acl_conf['object_id'])
                else:
                    raise RuntimeError(
                        "Import job has failed. Refer to the previous log messages to investigate."
                    )
        # update the imported job names
        self.update_imported_job_names(error_logger,
                                       checkpoint_job_configs_set)
Example #27
    def import_hive_metastore(self,
                              cluster_name=None,
                              metastore_dir='metastore/',
                              views_dir='metastore_views/',
                              has_unicode=False,
                              should_repair_table=False):
        metastore_local_dir = self.get_export_dir() + metastore_dir
        metastore_view_dir = self.get_export_dir() + views_dir
        error_logger = logging_utils.get_error_logger(
            wmconstants.WM_IMPORT, wmconstants.METASTORE_TABLES,
            self.get_export_dir())
        checkpoint_metastore_set = self._checkpoint_service.get_checkpoint_key_set(
            wmconstants.WM_IMPORT, wmconstants.METASTORE_TABLES)
        os.makedirs(metastore_view_dir, exist_ok=True)
        (cid, ec_id) = self.get_or_launch_cluster(cluster_name)
        # get local databases
        db_list = self.listdir(metastore_local_dir)
        # make directory in DBFS root bucket path for tmp data
        self.post('/dbfs/mkdirs', {'path': '/tmp/migration/'})
        # iterate over the databases saved locally
        all_db_details_json = self.get_database_detail_dict()
        for db_name in db_list:
            # create a dir to host the view ddl if we find them
            os.makedirs(metastore_view_dir + db_name, exist_ok=True)
            # get the local database path to list tables
            local_db_path = metastore_local_dir + db_name
            # get a dict of the database attributes
            database_attributes = all_db_details_json.get(db_name, {})
            if not database_attributes:
                logging.info(all_db_details_json)
                raise ValueError(
                    'Missing Database Attributes Log. Re-run metastore export')
            create_db_resp = self.create_database_db(db_name, ec_id, cid,
                                                     database_attributes)
            if logging_utils.log_reponse_error(error_logger, create_db_resp):
                logging.error(
                    f"Failed to create database {db_name} during metastore import. Exiting Import."
                )
                return
            db_path = database_attributes.get('Location')
            if os.path.isdir(local_db_path):
                # all databases should be directories, no files at this level
                # list all the tables in the database local dir
                tables = self.listdir(local_db_path)
                for tbl_name in tables:
                    # build the path for the table where the ddl is stored
                    full_table_name = f"{db_name}.{tbl_name}"
                    if not checkpoint_metastore_set.contains(full_table_name):
                        logging.info(f"Importing table {full_table_name}")
                        local_table_ddl = metastore_local_dir + db_name + '/' + tbl_name
                        if not self.move_table_view(db_name, tbl_name,
                                                    local_table_ddl):
                            # we hit a table ddl here, so we apply the ddl
                            resp = self.apply_table_ddl(
                                local_table_ddl, ec_id, cid, db_path,
                                has_unicode)
                            if not logging_utils.log_reponse_error(
                                    error_logger, resp):
                                checkpoint_metastore_set.write(full_table_name)
                        else:
                            logging.info(
                                f'Moving view ddl to re-apply later: {db_name}.{tbl_name}'
                            )
            else:
                logging.error(
                    "Error: Only databases should exist at this level: {0}".
                    format(db_name))
            self.delete_dir_if_empty(metastore_view_dir + db_name)
        views_db_list = self.listdir(metastore_view_dir)
        for db_name in views_db_list:
            local_view_db_path = metastore_view_dir + db_name
            database_attributes = all_db_details_json.get(db_name, '')
            db_path = database_attributes.get('Location')
            if os.path.isdir(local_view_db_path):
                views = self.listdir(local_view_db_path)
                for view_name in views:
                    full_view_name = f'{db_name}.{view_name}'
                    if not checkpoint_metastore_set.contains(full_view_name):
                        logging.info(f"Importing view {full_view_name}")
                        local_view_ddl = metastore_view_dir + db_name + '/' + view_name
                        resp = self.apply_table_ddl(local_view_ddl, ec_id, cid,
                                                    db_path, has_unicode)
                        if not logging_utils.log_reponse_error(error_logger, resp):
                            checkpoint_metastore_set.write(full_view_name)
                        logging.info(resp)

        # repair legacy tables
        if should_repair_table:
            self.report_legacy_tables_to_fix()
            self.repair_legacy_tables(cluster_name)