Example #1
class Registry(object):
    """CFDE Registry binding.

    """
    def __init__(self,
                 scheme='https',
                 servername='app.nih-cfde.org',
                 catalog='registry',
                 credentials=None,
                 session_config=None):
        """Bind to specified registry.

        Note: this binding operates as an authenticated client
        identity and may expose different capabilities depending on
        the client's role within the organization.
        """
        if credentials is None:
            credentials = get_credential(servername)
        if not session_config:
            session_config = DEFAULT_SESSION_CONFIG.copy()
        session_config["allow_retry_on_all_methods"] = True
        self._catalog = ErmrestCatalog(scheme,
                                       servername,
                                       catalog,
                                       credentials,
                                       session_config=session_config)
        self._builder = self._catalog.getPathBuilder()
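
    # Usage sketch (not part of the original module): bind to the registry and
    # look up a DCC record. The server name matches this class's default; the
    # DCC id below is a hypothetical placeholder.
    #
    #   registry = Registry(servername='app.nih-cfde.org')
    #   dcc_rows = registry.get_dcc('cfde_registry_dcc:example')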

    def validate_dcc_id(self, dcc_id, submitting_user):
        """Validate that user has submitter role with this DCC according to registry.

        :param dcc_id: The dcc.id key of the DCC in the registry.
        :param submitting_user: The WebauthnUser representation of the authenticated submission user.

        Raises UnknownDccId for invalid DCC identifiers.
        Raises Forbidden if submitting_user is not a submitter for the named DCC.
        """
        rows = self.get_dcc(dcc_id)
        if len(rows) < 1:
            raise exception.UnknownDccId(dcc_id)
        self.enforce_dcc_submission(dcc_id, submitting_user)

    def _get_entity(self, table_name, id=None):
        """Get one or all entity records from a registry table.

        :param table_name: The registry table to access.
        :param id: A key to retrieve one row (default None retrieves all)
        """
        path = self._builder.CFDE.tables[table_name].path
        if id is not None:
            path = path.filter(path.table_instances[table_name].
                               column_definitions['id'] == id)
        return list(path.entities().fetch())

    def list_datapackages(self):
        """Get a list of all datapackage submissions in the registry

        """
        return self._get_entity('datapackage')

    def get_latest_approved_datapackages(self,
                                         need_dcc_appr=True,
                                         need_cfde_appr=True):
        """Get a map of latest datapackages approved for release for each DCC id."""
        path = self._builder.CFDE.tables['datapackage'].path
        status = path.datapackage.status
        path = path.filter(
            (status == terms.cfde_registry_dp_status.content_ready)
            | (status == terms.cfde_registry_dp_status.release_pending))
        if need_dcc_appr:
            path = path.filter(path.datapackage.dcc_approval_status ==
                               terms.cfde_registry_decision.approved)
        if need_cfde_appr:
            path = path.filter(path.datapackage.cfde_approval_status ==
                               terms.cfde_registry_decision.approved)
        res = {}
        for row in path.entities().sort(path.datapackage.submitting_dcc,
                                        path.datapackage.submission_time.desc):
            if row['submitting_dcc'] not in res:
                res[row['submitting_dcc']] = row
        return res

    def get_datapackage(self, id):
        """Get datapackage by submission id or raise exception.
        
        :param id: The datapackage.id key for the submission in the registry

        Raises DatapackageUnknown if record is not found.
        """
        rows = self._get_entity('datapackage', id)
        if len(rows) < 1:
            raise exception.DatapackageUnknown(
                'Datapackage "%s" not found in registry.' % (id, ))
        return rows[0]

    def get_datapackage_table(self, datapackage, position):
        """Get datapackage by submission id or raise exception.

        :param datapackage: The datapackage.id key for the submission in the registry
        :param position: The 0-based index of the table in the datapackage's list of resources

        Raises IndexError if record is not found.
        """
        path = self._builder.CFDE.datapackage_table.path
        path = path.filter(path.datapackage_table.datapackage == datapackage)
        path = path.filter(path.datapackage_table.position == position)
        rows = list(path.entities().fetch())
        if len(rows) < 1:
            raise IndexError(
                'Datapackage table ("%s", %d) not found in registry.' %
                (datapackage, position))
        return rows[0]

    def register_release(self, id, dcc_datapackages, description=None):
        """Idempotently register new release in registry, returning (release row, dcc_datapackages).

        :param id: The release.id for the new record
        :param dcc_datapackages: A dict mapping {dcc_id: datapackage, ...} for constituents
        :param description: A human-readable description of this release

        The constituents are a set of datapackage records (dicts) as
        returned by the get_datapackage() method. The dcc_id key MUST
        match the submitting_dcc of the record.

        For repeat calls on existing releases, the definition will be
        updated if the release is still in the planning state, but a
        StateError will be raised if it is no longer in planning state.

        """
        for dcc_id, dp in dcc_datapackages.items():
            if dcc_id != dp['submitting_dcc']:
                raise ValueError(
                    'Mismatch in dcc_datapackages DCC IDs %s != %s' %
                    (dcc_id, dp['submitting_dcc']))

        try:
            rel, old_dcc_dps = self.get_release(id)
        except exception.ReleaseUnknown:
            # create new release record
            newrow = {
                'id': id,
                'status': terms.cfde_registry_rel_status.planning,
                'description':
                None if description is nochange else description,
            }
            defaults = [
                cname for cname in
                self._builder.CFDE.release.column_definitions.keys()
                if cname not in newrow
            ]
            logger.info('Registering new release %s' % (id, ))
            self._catalog.post('/entity/CFDE:release?defaults=%s' %
                               (','.join(defaults), ),
                               json=[newrow])
            rel, old_dcc_dps = self.get_release(id)

        if rel['status'] != terms.cfde_registry_rel_status.planning:
            raise exception.StateError(
                'Idempotent registration disallowed on existing release %(id)s with status=%(status)s'
                % rel)

        # prepare for idempotent updates
        old_dp_ids = {dp['id'] for dp in old_dcc_dps.values()}
        dp_ids = {dp['id'] for dp in dcc_datapackages.values()}
        datapackages = {dp['id']: dp for dp in dcc_datapackages.values()}

        # idempotently revise description
        if rel['description'] != description:
            logger.info('Updating release %s description: %s' % (
                id,
                description,
            ))
            self.update_release(id, description=description)

        # find currently registered constituents
        path = self._builder.CFDE.dcc_release_datapackage.path
        path = path.filter(path.dcc_release_datapackage.release == id)
        old_dp_ids = {row['datapackage'] for row in path.entities().fetch()}

        # remove stale constituents
        for dp_id in old_dp_ids.difference(dp_ids):
            logger.info('Removing constituent datapackage %s from release %s' %
                        (dp_id, id))
            self._catalog.delete(
                '/entity/CFDE:dcc_release_datapackage/release=%s&datapackage=%s'
                % (
                    urlquote(id),
                    urlquote(dp_id),
                ))

        # add new constituents
        new_dp_ids = dp_ids.difference(old_dp_ids)
        if new_dp_ids:
            logger.info('Adding constituent datapackages %s to release %s' %
                        (new_dp_ids, id))
            self._catalog.post('/entity/CFDE:dcc_release_datapackage',
                               json=[{
                                   'dcc':
                                   datapackages[dp_id]['submitting_dcc'],
                                   'release':
                                   id,
                                   'datapackage':
                                   dp_id,
                               } for dp_id in new_dp_ids])

        # return registry content
        return self.get_release(id)
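
    # Illustrative sketch of idempotent release registration (ids below are
    # hypothetical placeholders, not values defined here). dcc_datapackages
    # maps each DCC id to a datapackage row as returned by get_datapackage():
    #
    #   dp = registry.get_datapackage('example-datapackage-id')
    #   rel, constituents = registry.register_release(
    #       'example-release-id',
    #       {dp['submitting_dcc']: dp},
    #       description='Example release',
    #   )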

    def get_release(self, id):
        """Get release by submission id or raise exception, returning (release_row, dcc_datapackages).
        
        :param id: The release.id key for the release definition in the registry

        Raises ReleaseUnknown if record is not found.
        """
        rows = self._get_entity('release', id)
        if len(rows) < 1:
            raise exception.ReleaseUnknown(
                'Release "%s" not found in registry.' % (id, ))
        rel = rows[0]
        path = self._builder.CFDE.dcc_release_datapackage.path
        path = path.filter(path.dcc_release_datapackage.release == id)
        path = path.link(self._builder.CFDE.datapackage)
        return rel, {
            row['submitting_dcc']: row
            for row in path.entities().fetch()
        }

    def register_datapackage(self, id, dcc_id, submitting_user, archive_url):
        """Idempotently register new submission in registry.

        :param id: The datapackage.id for the new record
        :param dcc_id: The datapackage.submitting_dcc for the new record
        :param submitting_user: The datapackage.submitting_user for the new record
        :param archive_url: The datapackage.datapackage_url for the new record

        May raise non-CfdeError exceptions on operational errors.
        """
        try:
            return self.get_datapackage(id)
        except exception.DatapackageUnknown:
            pass

        # poke the submitting user into the registry's user-tracking table in case they don't exist
        # this acts as controlled domain table for submitting_user fkeys
        self._catalog.post('/entity/public:ERMrest_Client?onconflict=skip',
                           json=[{
                               'ID': submitting_user.webauthn_id,
                               'Display_Name': submitting_user.display_name,
                               'Full_Name': submitting_user.full_name,
                               'Email': submitting_user.email,
                               'Client_Object': {
                                   'id': submitting_user.webauthn_id,
                                   'display_name':
                                   submitting_user.display_name,
                               }
                           }])

        newrow = {
            "id": id,
            "submitting_dcc": dcc_id,
            "submitting_user": submitting_user.webauthn_id,
            "datapackage_url": archive_url,
            # we need to supply these unless catalog starts giving default values for us
            "submission_time": datetime.datetime.utcnow().isoformat(),
            "status": terms.cfde_registry_dp_status.submitted,
        }
        defaults = [
            cname for cname in
            self._builder.CFDE.datapackage.column_definitions.keys()
            if cname not in newrow
        ]
        self._catalog.post('/entity/CFDE:datapackage?defaults=%s' %
                           (','.join(defaults), ),
                           json=[newrow])
        # kind of redundant, but make sure we round-trip this w/ server-applied defaults?
        return self.get_datapackage(id)

    def register_datapackage_table(self, datapackage, position, table_name):
        """Idempotently register new datapackage table in registry.

        :param datapackage: The datapackage.id for the containing datapackage
        :param position: The integer position of this table in the datapackage's list of resources
        :param table_name: The "name" field of the tabular resource

        """
        newrow = {
            'datapackage': datapackage,
            'position': position,
            'table_name': table_name,
            'status': terms.cfde_registry_dpt_status.enumerated,
            'num_rows': None,
            'diagnostics': None,
        }

        rows = self._catalog.post(
            '/entity/CFDE:datapackage_table?onconflict=skip',
            json=[newrow]).json()

        if len(rows) == 0:
            # row already exists
            self.update_datapackage_table(
                datapackage,
                position,
                status=terms.cfde_registry_dpt_status.enumerated)

    def update_release(self,
                       id,
                       status=nochange,
                       description=nochange,
                       cfde_approval_status=nochange,
                       release_time=nochange,
                       ermrest_url=nochange,
                       browse_url=nochange,
                       summary_url=nochange,
                       diagnostics=nochange):
        """Idempotently update release metadata in registry.

        :param id: The release.id of the existing record to update
        :param status: The new release.status value (default nochange)
        :param description: The new release.description value (default nochange)
        :param cfde_approval_status: The new release.cfde_approval_status value (default nochange)
        :param release_time: The new release.release_time value (default nochange)
        :param ermrest_url: The new release.review_ermrest_url value (default nochange)
        :param browse_url: The new release.review_browse_url value (default nochange)
        :param summary_url: The new release.review_summary_url value (default nochange)
        :param diagnostics: The new release.diagnostics value (default nochange)

        The special `nochange` singleton value used as default for
        optional arguments represents the desire to keep whatever
        current value exists for that field in the registry.

        May raise non-CfdeError exceptions on operational errors.
        """
        if not isinstance(id, str):
            raise TypeError('expected id of type str, not %s' % (type(id), ))
        existing, existing_dcc_dps = self.get_release(id)
        changes = {
            k: v
            for k, v in {
                'status': status,
                'description': description,
                'cfde_approval_status': cfde_approval_status,
                'release_time': release_time,
                'ermrest_url': ermrest_url,
                'browse_url': browse_url,
                'summary_url': summary_url,
                'diagnostics': diagnostics,
            }.items() if v is not nochange and v != existing[k]
        }
        if not changes:
            return
        changes['id'] = id
        self._catalog.put('/attributegroup/CFDE:release/id;%s' %
                          (','.join([c
                                     for c in changes.keys() if c != 'id']), ),
                          json=[changes])
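
    # Sketch of the nochange convention (hypothetical id): only fields passed
    # explicitly are written back; omitted fields keep their registry values.
    #
    #   registry.update_release('example-release-id', description='Revised description')
    #   # status, release_time, diagnostics, etc. remain unchanged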

    def update_datapackage(self,
                           id,
                           status=nochange,
                           diagnostics=nochange,
                           review_ermrest_url=nochange,
                           review_browse_url=nochange,
                           review_summary_url=nochange):
        """Idempotently update datapackage metadata in registry.

        :param id: The datapackage.id of the existing record to update
        :param status: The new datapackage.status value (default nochange)
        :param diagnostics: The new datapackage.diagnostics value (default nochange)
        :param review_ermrest_url: The new datapackage.review_ermrest_url value (default nochange)
        :param review_browse_url: The new datapackage.review_browse_url value (default nochange)
        :param review_summary_url: The new datapackage.review_summary_url value (default nochange)

        The special `nochange` singleton value used as default for
        optional arguments represents the desire to keep whatever
        current value exists for that field in the registry.

        May raise non-CfdeError exceptions on operational errors.
        """
        if not isinstance(id, str):
            raise TypeError('expected id of type str, not %s' % (type(id), ))
        existing = self.get_datapackage(id)
        changes = {
            k: v
            for k, v in {
                'status': status,
                'diagnostics': diagnostics,
                'review_ermrest_url': review_ermrest_url,
                'review_browse_url': review_browse_url,
                'review_summary_url': review_summary_url,
            }.items() if v is not nochange and v != existing[k]
        }
        if not changes:
            return
        changes['id'] = id
        self._catalog.put('/attributegroup/CFDE:datapackage/id;%s' %
                          (','.join([c
                                     for c in changes.keys() if c != 'id']), ),
                          json=[changes])

    def update_datapackage_table(self,
                                 datapackage,
                                 position,
                                 status=nochange,
                                 num_rows=nochange,
                                 diagnostics=nochange):
        """Idempotently update datapackage_table metadata in registry.

        :param datapackage: The datapackage_table.datapackage key value
        :param position: The datapackage_table.position key value
        :param status: The new datapackage_table.status value (default nochange)
        :param num_rows: The new datapackage_table.num_rows value (default nochange)
        :param diagnostics: The new datapackage_table.diagnostics value (default nochange)

        """
        if not isinstance(datapackage, str):
            raise TypeError('expected datapackage of type str, not %s' %
                            (type(datapackage), ))
        if not isinstance(position, int):
            raise TypeError('expected position of type int, not %s' %
                            (type(position), ))
        existing = self.get_datapackage_table(datapackage, position)
        changes = {
            k: v
            for k, v in {
                'status': status,
                'num_rows': num_rows,
                'diagnostics': diagnostics,
            }.items() if v is not nochange and v != existing[k]
        }
        if not changes:
            return
        changes.update({
            'datapackage': datapackage,
            'position': position,
        })
        self._catalog.put(
            '/attributegroup/CFDE:datapackage_table/datapackage,position;%s' %
            (','.join([
                c for c in changes.keys()
                if c not in {'datapackage', 'position'}
            ]), ),
            json=[changes])

    def get_dcc(self, dcc_id=None):
        """Get one or all DCC records from the registry.

        :param dcc_id: Optional dcc.id key string to limit results to single DCC (default None)

        Returns a list of dict-like records representing rows of the
        registry dcc table, optionally restricted to a specific dcc.id
        key.
        """
        return self._get_entity('dcc', dcc_id)

    def get_group(self, group_id=None):
        """Get one or all group records from the registry.

        :param group_id: Optional group.id key string to limit results to single group (default None)

        Returns a list of dict-like records representing rows of the
        registry group table, optionally restricted to a specific group.id
        key.
        """
        return self._get_entity('group', group_id)

    def get_group_role(self, role_id=None):
        """Get one or all group-role records from the registry.

        :param role_id: Optional group_role.id key string to limit results to single role (default None)

        Returns a list of dict-like records representing rows of the
        registry group_role table, optionally restricted to a specific
        group_role.id key.
        """
        return self._get_entity('group_role', role_id)

    def get_groups_by_dcc_role(self, role_id=None, dcc_id=None):
        """Get groups by DCC x role for one or all roles and DCCs.

        :param role_id: Optional role.id key string to limit results to a single group role (default None)
        :param dcc_id: Optional dcc.id key string to limit results to a single DCC (default None)

        Returns a list of dict-like records associating a DCC id, a
        role ID, and a list of group IDs suitable as an ACL for that
        particular dcc-role combination.
        """
        # find range of possible values
        dccs = {row['id']: row for row in self.get_dcc(dcc_id)}
        roles = {row['id']: row for row in self.get_group_role(role_id)}

        # find mapped groups (an inner join)
        path = self._builder.CFDE.dcc_group_role.path.link(
            self._builder.CFDE.group)
        if role_id is not None:
            path = path.filter(path.dcc_group_role.role == role_id)
        if dcc_id is not None:
            path = path.filter(path.dcc_group_role.dcc == dcc_id)
        dcc_roles = {
            (row['dcc'], row['role']): row
            for row in path.groupby(path.dcc_group_role.dcc, path.dcc_group_role.role) \
            .attributes(ArrayD(path.group).alias("groups")) \
            .fetch()
        }

        # as a convenience for simple consumers, emulate a full outer
        # join pattern to return empty lists for missing combinations
        return [
            (
                dcc_roles[(dcc_id, role_id)] \
                if (dcc_id, role_id) in dcc_roles \
                else {"dcc": dcc_id, "role": role_id, "groups": []}
            )
            for dcc_id in dccs
            for role_id in roles
        ]
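
    # The emulated outer join means every (dcc, role) combination appears in
    # the result, e.g. (hypothetical ids):
    #
    #   [{'dcc': 'example-dcc', 'role': 'submitter-role-id', 'groups': [...]},
    #    {'dcc': 'example-dcc', 'role': 'reviewer-role-id', 'groups': []},
    #    ...]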

    def get_dcc_acl(self, dcc_id, role_id):
        """Get groups for one DCC X group_role as a webauthn-style ACL.

        :param dcc_id: A dcc.id key known by the registry.
        :param role_id: A group_role.id key known by the registry.

        Returns a list of webauthn ID strings as an access control
        list suitable for intersection tests with
        WebauthnUser.acl_authz_test().
        """
        acl = set()
        for row in self.get_groups_by_dcc_role(role_id, dcc_id):
            acl.update({grp['webauthn_id'] for grp in row['groups']})
        return list(sorted(acl))

    def enforce_dcc_submission(self, dcc_id, submitting_user):
        """Verify that submitting_user is authorized to submit datapackages for dcc_id.

        :param dcc_id: The dcc.id key of the DCC in the registry
        :param submitting_user: The WebauthnUser representation of the user context.

        Raises Forbidden if user does not have submitter role for DCC.
        """
        submitting_user.acl_authz_test(
            self.get_dcc_acl(dcc_id, terms.cfde_registry_grp_role.submitter),
            'Submission to DCC %s is forbidden' % (dcc_id, ))

    @classmethod
    def dump_onboarding(cls, registry_datapackage):
        """Dump onboarding info about DCCs in registry"""
        resources = [
            resource
            for resource in registry_datapackage.package_def['resources']
            if resource['name'] in {'dcc', 'group', 'dcc_group_role'}
        ]
        registry_datapackage.dump_data_files(resources=resources)
Example #2
class AclConfig:
    NC_NAME = 'name'
    GC_NAME = 'groups'
    ACL_TYPES = [
        "catalog_acl", "schema_acls", "table_acls", "column_acls",
        "foreign_key_acls"
    ]
    GLOBUS_PREFIX = 'https://auth.globus.org/'
    ROBOT_PREFIX_FORMAT = 'https://{server}/webauthn_robot/'

    def __init__(self,
                 server,
                 catalog_id,
                 config_file,
                 credentials,
                 schema_name=None,
                 table_name=None,
                 verbose=False):
        with open(config_file) as config_fp:
            self.config = json.load(config_fp)
        self.ignored_schema_patterns = []
        self.verbose = verbose
        self.server = server
        self.catalog_id = catalog_id
        ip = self.config.get("ignored_schema_patterns")
        if ip is not None:
            for p in ip:
                self.ignored_schema_patterns.append(re.compile(p))
        self.acl_specs = {"catalog_acl": self.config.get("catalog_acl")}
        for key in self.ACL_TYPES:
            if key != "catalog_acl":
                self.acl_specs[key] = self.make_speclist(key)
        self.groups = self.config.get("groups")
        self.expand_groups()
        self.acl_definitions = self.config.get("acl_definitions")
        self.expand_acl_definitions()
        self.acl_bindings = self.config.get("acl_bindings")
        self.invalidate_bindings = self.config.get("invalidate_bindings")
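
        # Sketch of the config file shape implied by the keys read above
        # (values are hypothetical, not part of the original module):
        #
        #   {
        #     "catalog_acl": {"acl": "admin_only"},
        #     "schema_acls": [...], "table_acls": [...],
        #     "column_acls": [...], "foreign_key_acls": [...],
        #     "groups": {"admins": ["https://auth.globus.org/<uuid>"]},
        #     "acl_definitions": {"admin_only": {"owner": ["admins"]}},
        #     "acl_bindings": {...}, "invalidate_bindings": {...},
        #     "ignored_schema_patterns": ["^_.*"],
        #     "group_list_table": {"schema": "public", "table": "group_lists"}
        #   }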

        old_catalog = ErmrestCatalog('https', self.server, self.catalog_id,
                                     credentials)
        self.saved_toplevel_config = ConfigUtil.find_toplevel_node(
            old_catalog.getCatalogModel(), schema_name, table_name)
        self.catalog = ErmrestCatalog('https', self.server, self.catalog_id,
                                      credentials)
        self.toplevel_config = ConfigUtil.find_toplevel_node(
            self.catalog.getCatalogModel(), schema_name, table_name)

    def make_speclist(self, name):
        d = self.config.get(name)
        if d is None:
            d = dict()
        return ACLSpecList(d)

    def add_node_acl(self, node, acl_name):
        acl = self.acl_definitions.get(acl_name)
        if acl is None:
            raise ValueError(
                "no acl set called '{name}'".format(name=acl_name))
        for k in acl.keys():
            node.acls[k] = acl[k]

    def add_node_acl_binding(self, node, table_node, binding_name):
        if binding_name not in self.acl_bindings:
            raise ValueError(
                "no acl binding called '{name}'".format(name=binding_name))
        binding = self.acl_bindings.get(binding_name)
        try:
            node.acl_bindings[binding_name] = self.expand_acl_binding(
                binding, table_node)
        except NoForeignKeyError as e:
            detail = ''
            if isinstance(node, ermrest_model.Column):
                detail = 'on column {n}'.format(n=node.name)
            elif isinstance(node, ermrest_model.ForeignKey):
                detail = 'on foreign key {s}.{n}'.format(s=node.names[0][0],
                                                         n=node.names[0][1])
            else:
                detail = ' {t}'.format(t=type(node))
            print("couldn't expand acl binding {b} {d} table {s}.{t}".format(
                b=binding_name,
                d=detail,
                s=table_node.schema.name,
                t=table_node.name))
            raise e

    def expand_acl_binding(self, binding, table_node):
        if not isinstance(binding, dict):
            return binding
        new_binding = dict()
        for k in binding.keys():
            if k == "projection":
                new_binding[k] = []
                for proj in binding.get(k):
                    new_binding[k].append(
                        self.expand_projection(proj, table_node))
            elif k == "scope_acl":
                new_binding[k] = self.get_group(binding.get(k))
            else:
                new_binding[k] = binding[k]
        return new_binding

    def expand_projection(self, proj, table_node):
        if isinstance(proj, dict):
            new_proj = dict()
            is_first_outbound = True
            for k in proj.keys():
                if k == "outbound_col":
                    if is_first_outbound:
                        is_first_outbound = False
                    else:
                        raise NotImplementedError(
                            "don't know how to expand 'outbound_col' on anything but the first entry in a projection; "
                            "use 'outbound' instead")
                    if table_node is None:
                        raise NotImplementedError(
                            "don't know how to expand 'outbound_col' in a foreign key acl/annotation; use 'outbound' "
                            "instead")
                    new_proj["outbound"] = self.expand_projection_column(
                        proj[k], table_node)
                    if new_proj["outbound"] is None:
                        return None
                else:
                    new_proj[k] = proj[k]
                    is_first_outbound = False
            return new_proj
        else:
            return proj

    def expand_projection_column(self, col_name, table_node):
        for fkey in table_node.foreign_keys:
            if len(fkey.foreign_key_columns) == 1:
                col = fkey.foreign_key_columns[0]
                if col.get("table_name") == table_node.name and col.get(
                        "schema_name") == table_node.schema.name and col.get(
                            "column_name") == col_name:
                    return fkey.names[0]
        raise NoForeignKeyError(
            "can't find foreign key for column %s.%s(%s)" %
            (table_node.schema.name, table_node.name, col_name))

    def set_node_acl_bindings(self, node, table_node, binding_list,
                              invalidate_list):
        node.acl_bindings.clear()
        if binding_list is not None:
            for binding_name in binding_list:
                self.add_node_acl_binding(node, table_node, binding_name)
        if invalidate_list is not None:
            for binding_name in invalidate_list:
                if binding_list and binding_name in binding_list:
                    raise ValueError(
                        "Binding {b} appears in both acl_bindings and invalidate_bindings for table {s}.{t} node {n}"
                        .format(b=binding_name,
                                s=table_node.schema.name,
                                t=table_node.name,
                                n=node.name))
                node.acl_bindings[binding_name] = False

    def save_groups(self):
        glt = self.create_or_validate_group_table()
        if glt is not None and self.groups is not None:
            rows = []
            for name in self.groups.keys():
                row = {'name': name, 'groups': self.groups.get(name)}
                for c in ['RCB', 'RMB']:
                    if glt.getColumn(c) is not None:
                        row[c] = None
                rows.append(row)

            glt.upsertRows(self.catalog, rows)

    def create_or_validate_schema(self, schema_name):
        schema = self.catalog.getCatalogSchema()['schemas'].get(schema_name)
        if schema is None:
            self.catalog.post("/schema/{s}".format(s=schema_name))
        return self.catalog.getCatalogSchema()['schemas'].get(schema_name)

    def create_table(self, schema_name, table_name, table_spec, comment=None):
        if table_spec is None:
            table_spec = dict()
        if schema_name is None:
            return None
        table_spec["schema_name"] = schema_name
        table_spec["table_name"] = table_name
        if table_spec.get('comment') is None and comment is not None:
            table_spec['comment'] = comment
        if table_spec.get('kind') is None:
            table_spec['kind'] = 'table'
        self.catalog.post("/schema/{s}/table".format(s=schema_name),
                          json=table_spec)
        schema = self.catalog.getCatalogSchema()['schemas'].get(schema_name)
        return schema['tables'].get(table_name)

    def create_or_validate_group_table(self):
        glt_spec = self.config.get('group_list_table')
        if glt_spec is None:
            return None
        sname = glt_spec.get('schema')
        tname = glt_spec.get('table')
        if sname is None or tname is None:
            raise ValueError("group_list_table missing schema or table")
        schema = self.create_or_validate_schema(sname)
        assert schema is not None
        glt = Table(schema['tables'].get(tname))
        if glt == {}:
            glt_spec = ermrest_model.Table.define(
                tname,
                column_defs=[
                    ermrest_model.Column.define(
                        self.NC_NAME,
                        ermrest_model.builtin_types.text,
                        nullok=False,
                        comment=
                        'Name of grouplist, used in foreign keys. This table is maintained by the acl-config '
                        'program and should not be updated by hand.'),
                    ermrest_model.Column.define(
                        self.GC_NAME,
                        ermrest_model.builtin_types['text[]'],
                        nullok=True,
                        comment=
                        'List of groups. This table is maintained by the acl-config program and should not be '
                        'updated by hand.')
                ],
                key_defs=[
                    ermrest_model.Key.define([self.NC_NAME],
                                             constraint_names=[[
                                                 sname, "{t}_{c}_u".format(
                                                     t=tname, c=self.NC_NAME)
                                             ]])
                ],
                comment=
                "Named lists of groups used in ACLs. Maintained by the acl-config program. Do not update this "
                "table manually.",
                annotations={'tag:isrd.isi.edu,2016:generated': None})
            glt = Table(self.create_table(sname, tname, glt_spec))

        else:
            name_col = glt.getColumn(self.NC_NAME)
            if name_col is None:
                raise ValueError(
                    'table specified for group lists ({s}.{t}) lacks a "{n}" column'
                    .format(s=sname, t=tname, n=self.NC_NAME))
            if name_col.get('nullok'):
                raise ValueError(
                    "{n} column in group list table ({s}.{t}) allows nulls".
                    format(n=self.NC_NAME, s=sname, t=tname))

            nc_uniq = False
            for key in glt.get('keys'):
                cols = key.get('unique_columns')
                if len(cols) == 1 and cols[0] == self.NC_NAME:
                    nc_uniq = True
                    break
            if not nc_uniq:
                raise ValueError(
                    "{n} column in group list table ({s}.{t}) is not a key".
                    format(n=self.NC_NAME, s=sname, t=tname))

            val_col = glt.getColumn(self.GC_NAME)
            if val_col is None:
                raise ValueError(
                    'table specified for group lists ({s}.{t}) lacks a "{n}" column'
                    .format(s=sname, t=tname, n=self.GC_NAME))
        if glt == {}:
            return None
        else:
            return glt

    def set_node_acl(self, node, spec):
        node.acls.clear()
        acl_name = spec.get("acl")
        if acl_name is not None:
            self.add_node_acl(node, acl_name)

    def expand_groups(self):
        for group_name in self.groups.keys():
            self.expand_group(group_name)

    def get_group(self, group_name):
        group = self.groups.get(group_name)
        if group is None:
            group = [group_name]
        return group

    def validate_group(self, group):
        if group == '*':
            return
        elif group.startswith(self.GLOBUS_PREFIX):
            self.validate_globus_group(group)
        elif group.startswith(
                self.ROBOT_PREFIX_FORMAT.format(server=self.server)):
            self.validate_webauthn_robot(group)
        else:
            warnings.warn(
                "Can't determine format of group '{g}'".format(g=group))

    def validate_globus_group(self, group):
        guid = group[len(self.GLOBUS_PREFIX):]
        try:
            UUID(guid)
        except ValueError:
            raise ValueError(
                "Group '{g}' appears to be a malformed Globus group".format(
                    g=group))
        if self.verbose:
            print(
                "group '{g}' appears to be a syntactically-correct Globus group"
                .format(g=group))

    def validate_webauthn_robot(self, group):
        robot_name = group[
            len(self.ROBOT_PREFIX_FORMAT.format(server=self.server)):]
        if not robot_name:
            raise ValueError(
                "Group '{g}' appears to be a malformed webauthn robot identity"
                .format(g=group))
        if self.verbose:
            print(
                "group '{g}' appears to be a syntactically-correct webauthn robot identity"
                .format(g=group))

    def expand_group(self, group_name):
        groups = []
        for child_name in self.groups.get(group_name):
            child = self.groups.get(child_name)
            if child is None:
                self.validate_group(child_name)
                groups.append(child_name)
            else:
                self.expand_group(child_name)
                groups = groups + self.groups[child_name]
        self.groups[group_name] = list(set(groups))
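
    # Example of the flattening behavior (hypothetical group names): given
    # config groups {"all": ["admins", "curators"],
    #                "admins": ["https://auth.globus.org/<uuid-a>"],
    #                "curators": ["https://auth.globus.org/<uuid-b>"]},
    # expand_group("all") leaves self.groups["all"] as the deduplicated list
    # of the two Globus URLs.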

    def expand_acl_definitions(self):
        for acl_name in self.acl_definitions.keys():
            self.expand_acl_definition(acl_name)

    def expand_acl_definition(self, acl_name):
        spec = self.acl_definitions.get(acl_name)
        for op_type in spec.keys():
            groups = []
            raw_groups = spec[op_type]
            if isinstance(raw_groups, list):
                for group_name in spec[op_type]:
                    groups = groups + self.get_group(group_name)
            else:
                groups = self.get_group(raw_groups)
            spec[op_type] = groups

    def set_table_acls(self, table):
        spec = self.acl_specs["table_acls"].find_best_table_spec(
            table.schema.name, table.name)
        table.acls.clear()
        table.acl_bindings.clear()
        if spec is not None:
            self.set_node_acl(table, spec)
            self.set_node_acl_bindings(table, table, spec.get("acl_bindings"),
                                       spec.get("invalidate_bindings"))
        if self.verbose:
            print("set table {s}.{t} acls to {a}, bindings to {b}".format(
                s=table.schema.name,
                t=table.name,
                a=str(table.acls),
                b=str(table.acl_bindings)))
        for column in table.column_definitions:
            self.set_column_acls(column, table)
        for fkey in table.foreign_keys:
            self.set_fkey_acls(fkey, table)

    def set_column_acls(self, column, table):
        spec = self.acl_specs["column_acls"].find_best_column_spec(
            column.table.schema.name, column.table.name, column.name)
        column.acls.clear()
        column.acl_bindings.clear()
        if spec is not None:
            self.set_node_acl(column, spec)
            self.set_node_acl_bindings(column, table, spec.get("acl_bindings"),
                                       spec.get("invalidate_bindings"))
        if self.verbose:
            print("set column {s}.{t}.{c} acls to {a}, bindings to {b}".format(
                s=column.table.schema.name,
                t=column.table.name,
                c=column.name,
                a=str(column.acls),
                b=str(column.acl_bindings)))

    def set_fkey_acls(self, fkey, table):
        spec = self.acl_specs["foreign_key_acls"].find_best_foreign_key_spec(
            fkey.table.schema.name, fkey.table.name, fkey.names)
        fkey.acls.clear()
        fkey.acl_bindings.clear()
        if spec is not None:
            self.set_node_acl(fkey, spec)
            self.set_node_acl_bindings(fkey, table, spec.get("acl_bindings"))
        if self.verbose:
            print("set fkey {f} acls to {a}, bindings to {b}".format(
                f=str(fkey.names), a=str(fkey.acls), b=str(fkey.acl_bindings)))

    def set_catalog_acls(self, catalog):
        spec = self.acl_specs["catalog_acl"]
        if spec is not None:
            catalog.acls.clear()
            self.set_node_acl(catalog, spec)
        if self.verbose:
            print("set catalog acls to {a}".format(a=str(catalog.acls)))
        for schema in self.toplevel_config.schemas.values():
            self.set_schema_acls(schema)

    def set_schema_acls(self, schema):
        for pattern in self.ignored_schema_patterns:
            if pattern.match(schema.name) is not None:
                print("ignoring schema {s}".format(s=schema.name))
                return
        spec = self.acl_specs["schema_acls"].find_best_schema_spec(schema.name)
        schema.acls.clear()
        if spec is not None:
            self.set_node_acl(schema, spec)
        if self.verbose:
            print("set schema {s} acls to {a}".format(s=schema.name,
                                                      a=str(schema.acls)))

        for table in schema.tables.values():
            self.set_table_acls(table)

    def set_acls(self):
        if isinstance(self.toplevel_config, ermrest_model.Model):
            self.set_catalog_acls(self.toplevel_config)
        elif isinstance(self.toplevel_config, ermrest_model.Schema):
            self.set_schema_acls(self.toplevel_config)
        elif isinstance(self.toplevel_config, ermrest_model.Table):
            self.set_table_acls(self.toplevel_config)
        else:
            raise ValueError("toplevel config is a {t}".format(
                t=str(type(self.toplevel_config))))

    def apply_acls(self):
        self.toplevel_config.apply(self.saved_toplevel_config)

    def dumps(self):
        """Dump a serialized (string) representation of the config.
        """
        return json.dumps(self.toplevel_config.prejson(), indent=2)
Example #3
class DerivaUpload(object):
    """
    Base class for upload tasks. Encapsulates a catalog instance and a hatrac store instance and provides some common
    and reusable functions.

    This class is not intended to be instantiated directly, but rather extended by a specific implementation.
    """

    DefaultConfigFileName = "config.json"
    DefaultServerListFileName = "servers.json"
    DefaultTransferStateFileName = "transfers.json"

    def __init__(self, config_file=None, credential_file=None, server=None):
        self.server_url = None
        self.catalog = None
        self.store = None
        self.config = None
        self.credentials = None
        self.asset_mappings = None
        self.transfer_state = dict()
        self.transfer_state_fp = None
        self.cancelled = False
        self.metadata = dict()

        self.file_list = OrderedDict()
        self.file_status = OrderedDict()
        self.skipped_files = set()
        self.override_config_file = config_file
        self.override_credential_file = credential_file
        self.server = self.getDefaultServer() if not server else server
        self.initialize()

    def __del__(self):
        self.cleanupTransferState()

    def initialize(self, cleanup=False):
        info = "%s v%s [Python %s, %s]" % (self.__class__.__name__, VERSION,
                                           platform.python_version(),
                                           platform.platform(aliased=True))
        logging.info("Initializing uploader: %s" % info)

        # cleanup invalidates the current configuration and credentials in addition to clearing internal state
        if cleanup:
            self.cleanup()
        # reset just clears the internal state
        else:
            self.reset()

        if not self.server:
            logging.warning(
                "A server was not specified and an internal default has not been set."
            )
            return

        # server variable initialization
        protocol = self.server.get('protocol', 'https')
        host = self.server.get('host', '')
        self.server_url = protocol + "://" + host
        catalog_id = self.server.get("catalog_id", "1")
        session_config = self.server.get('session')

        # overriden credential initialization
        if self.override_credential_file:
            self.credentials = get_credential(host, self.override_config_file)

        # catalog and file store initialization
        if self.catalog:
            del self.catalog
        self.catalog = ErmrestCatalog(protocol,
                                      host,
                                      catalog_id,
                                      self.credentials,
                                      session_config=session_config)
        if self.store:
            del self.store
        self.store = HatracStore(protocol,
                                 host,
                                 self.credentials,
                                 session_config=session_config)

        # transfer state initialization
        self.loadTransferState()
        """
         Configuration initialization - this is a bit complex because we allow for:
             1. Run-time overriding of the config file location.
             2. Sub-classes of this class to bundle their own default configuration files in an arbitrary location.
             3. The updating of already deployed configuration files if bundled internal defaults are newer.             
        """
        config_file = self.override_config_file if self.override_config_file else None
        # 1. If we don't already have a valid (i.e., overridden) path to a config file...
        if not (config_file and os.path.isfile(config_file)):
            # 2. Get the currently deployed config file path, which could possibly be overridden by subclass
            config_file = self.getDeployedConfigFilePath()
            # 3. If the deployed default path is not valid, OR, it is valid AND is older than the bundled default
            if (not (config_file and os.path.isfile(config_file))
                    or self.isFileNewer(self.getDefaultConfigFilePath(),
                                        self.getDeployedConfigFilePath())):
                # 4. If we can locate a bundled default config file,
                if os.path.isfile(self.getDefaultConfigFilePath()):
                    # 4.1 Copy the bundled default config file to the deployment-specific config path
                    copy_config(self.getDefaultConfigFilePath(), config_file)
                else:
                    # 4.2 Otherwise, fall back to writing a failsafe default based on internal hardcoded settings
                    write_config(config_file, DefaultConfig)
        # 5. Finally, read the resolved configuration file into a config object
        self._update_internal_config(read_config(config_file))
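
        # In short (hypothetical paths): an explicit override such as
        # /tmp/my-config.json wins; otherwise the deployed per-host file
        # (getDeployedConfigFilePath(), e.g. <config-dir>/<host>/config.json)
        # is used, refreshed from the bundled default whenever the bundled
        # copy is newer.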

    def _update_internal_config(self, config):
        """This updates the internal state of the uploader based on the config.
        """
        self.config = config
        # uploader initialization from configuration
        self.asset_mappings = self.config.get('asset_mappings', [])
        mu.add_types(self.config.get('mime_overrides'))

    def cancel(self):
        self.cancelled = True

    def reset(self):
        self.metadata.clear()
        self.file_list.clear()
        self.file_status.clear()
        self.skipped_files.clear()
        self.cancelled = False

    def cleanup(self):
        self.reset()
        self.config = None
        self.credentials = None
        self.cleanupTransferState()

    def setServer(self, server):
        cleanup = self.server != server
        self.server = server
        self.initialize(cleanup)

    def setCredentials(self, credentials):
        host = self.server['host']
        self.credentials = credentials
        self.catalog.set_credentials(self.credentials, host)
        self.store.set_credentials(self.credentials, host)

    @classmethod
    def getDefaultServer(cls):
        servers = cls.getServers()
        for server in servers:
            lower = {k.lower(): v for k, v in server.items()}
            if lower.get("default", False):
                return server
        return servers[0] if len(servers) else {}

    @classmethod
    def getServers(cls):
        """
        This method must be implemented by subclasses.
        """
        raise NotImplementedError(
            "This method must be implemented by a subclass.")

    @classmethod
    def getVersion(cls):
        """
        This method must be implemented by subclasses.
        """
        raise NotImplementedError(
            "This method must be implemented by a subclass.")

    @classmethod
    def getConfigPath(cls):
        """
        This method must be implemented by subclasses.
        """
        raise NotImplementedError(
            "This method must be implemented by a subclass.")

    @classmethod
    def getDeployedConfigPath(cls):
        return os.path.expanduser(os.path.normpath(cls.getConfigPath()))

    def getVersionCompatibility(self):
        return self.config.get("version_compatibility", list())

    def isVersionCompatible(self):
        compatibility = self.getVersionCompatibility()
        if len(compatibility) > 0:
            return vu.is_compatible(self.getVersion(), compatibility)
        else:
            return True

    @classmethod
    def getFileDisplayName(cls, file_path, asset_mapping=None):
        return os.path.basename(file_path)

    @staticmethod
    def isFileNewer(src, dst):
        if not (os.path.isfile(src) and os.path.isfile(dst)):
            return False

        # This comparison won't work with PyInstaller single-file bundles because the bundle is extracted to a temp dir
        # and every timestamp for every file in the bundle is reset to the bundle extraction/creation time.
        if getattr(sys, 'frozen', False):
            prefix = os.path.sep + "_MEI"
            if prefix in src:
                return False

        src_mtime = os.path.getmtime(os.path.abspath(src))
        dst_mtime = os.path.getmtime(os.path.abspath(dst))
        return src_mtime > dst_mtime

    @staticmethod
    def getFileSize(file_path):
        return os.path.getsize(file_path)

    @staticmethod
    def guessContentType(file_path):
        return mu.guess_content_type(file_path)

    @staticmethod
    def getFileHashes(file_path, hashes=frozenset(['md5'])):
        return hu.compute_file_hashes(file_path, hashes)

    @staticmethod
    def getCatalogTable(asset_mapping, metadata_dict=None):
        schema_name, table_name = asset_mapping.get('target_table',
                                                    [None, None])
        if not (schema_name and table_name):
            metadata_dict_lower = {
                k.lower(): v
                for k, v in metadata_dict.items()
            }
            schema_name = metadata_dict_lower.get("schema")
            table_name = metadata_dict_lower.get("table")
        if not (schema_name and table_name):
            raise ValueError(
                "Unable to determine target catalog table for asset type.")
        return '%s:%s' % (urlquote(schema_name), urlquote(table_name))
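
    # For example (hypothetical values): an asset_mapping of
    # {'target_table': ['isa', 'imaging_data']} yields 'isa:imaging_data';
    # with no target_table, the lowercased metadata_dict entries
    # {'schema': 'isa', 'table': 'imaging_data'} are used instead.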

    @staticmethod
    def interpolateDict(src, dst, allowNone=False):
        if not (isinstance(src, dict) and isinstance(dst, dict)):
            raise ValueError(
                "Invalid input parameter type(s): (src = %s, dst = %s), expected (dict, dict)"
                % (type(src).__name__, type(dst).__name__))

        dst = dst.copy()
        # prune None values from the src, we don't want those to be replaced with the string 'None' in the dest
        empty = [k for k, v in src.items() if v is None]
        for k in empty:
            del src[k]
        # perform the string replacement for the values in the destination dict
        for k, v in dst.items():
            try:
                value = v.format(**src)
            except KeyError:
                value = v
                if value:
                    if value.startswith('{') and value.endswith('}'):
                        value = None
            dst.update({k: value})
        # remove all None valued entries in the dest, if disallowed
        if not allowNone:
            empty = [k for k, v in dst.items() if v is None]
            for k in empty:
                del dst[k]

        return dst
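
    # Worked example (hypothetical values):
    #   src = {'name': 'img1.tiff', 'md5': None}
    #   dst = {'Filename': '{name}', 'Checksum': '{md5}', 'Type': 'image'}
    #   interpolateDict(src, dst) -> {'Filename': 'img1.tiff', 'Type': 'image'}
    # '{md5}' cannot be resolved ('md5' was pruned from src as None) and the
    # resulting None entry is dropped because allowNone defaults to False.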

    @staticmethod
    def pruneDict(src, dst, stringify=True):
        dst = dst.copy()
        for k in dst.keys():
            value = src.get(k)
            dst[k] = str(value) if (stringify and value is not None) else value
        return dst

    def getCurrentConfigFilePath(self):
        return self.override_config_file if self.override_config_file else self.getDeployedConfigFilePath(
        )

    def getDefaultConfigFilePath(self):
        return os.path.normpath(
            resource_path(os.path.join("conf", self.DefaultConfigFileName)))

    def getDeployedConfigFilePath(self):
        return os.path.join(self.getDeployedConfigPath(),
                            self.server.get('host', ''),
                            self.DefaultConfigFileName)

    def getDeployedTransferStateFilePath(self):
        return os.path.join(self.getDeployedConfigPath(),
                            self.server.get('host', ''),
                            self.DefaultTransferStateFileName)

    def getRemoteConfig(self):
        catalog_config = CatalogConfig.fromcatalog(self.catalog)
        return catalog_config.annotation_obj(
            "tag:isrd.isi.edu,2017:bulk-upload")

    def getUpdatedConfig(self):
        # if we are using an overridden config file, skip the update check
        if self.override_config_file:
            return

        logging.info("Checking for updated configuration...")
        remote_config = self.getRemoteConfig()
        if not remote_config:
            logging.info(
                "Remote configuration not present, using default local configuration file."
            )
            return

        deployed_config_file_path = self.getDeployedConfigFilePath()
        if os.path.isfile(deployed_config_file_path):
            current_md5 = hu.compute_file_hashes(deployed_config_file_path,
                                                 hashes=['md5'])['md5'][0]
        else:
            logging.info("Local config not found.")
            current_md5 = None
        tempdir = tempfile.mkdtemp(prefix="deriva_upload_")
        if os.path.exists(tempdir):
            updated_config_path = os.path.abspath(
                os.path.join(tempdir, DerivaUpload.DefaultConfigFileName))
            with io.open(updated_config_path,
                         'w',
                         newline='\n',
                         encoding='utf-8') as config:
                config.write(
                    json.dumps(remote_config,
                               ensure_ascii=False,
                               sort_keys=True,
                               separators=(',', ': '),
                               indent=2))
            new_md5 = hu.compute_file_hashes(updated_config_path,
                                             hashes=['md5'])['md5'][0]
            if current_md5 != new_md5:
                logging.info("Updated configuration found.")
                config = read_config(updated_config_path)
                self._update_internal_config(config)
            else:
                logging.info("Configuration is up-to-date.")
                config = None
            shutil.rmtree(tempdir, ignore_errors=True)

            return config

    def getFileStatusAsArray(self):
        result = list()
        for key in self.file_status.keys():
            item = {"File": key}
            item.update(self.file_status[key])
            result.append(item)
        return result

    def validateFile(self, root, path, name):
        file_path = os.path.normpath(os.path.join(path, name))
        asset_group, asset_mapping, groupdict = self.getAssetMapping(file_path)
        if not asset_mapping:
            return None

        return asset_group, asset_mapping, groupdict, file_path

    def scanDirectory(self, root, abort_on_invalid_input=False):
        """

        :param root:
        :param abort_on_invalid_input:
        :return:
        """
        root = os.path.abspath(root)
        if not os.path.isdir(root):
            raise ValueError("Invalid directory specified: [%s]" % root)

        logging.info("Scanning files in directory [%s]..." % root)
        file_list = OrderedDict()
        for path, dirs, files in walk(root):
            for file_name in files:
                file_path = os.path.normpath(os.path.join(path, file_name))
                file_entry = self.validateFile(root, path, file_name)
                if not file_entry:
                    logging.info(
                        "Skipping file: [%s] -- Invalid file type or directory location."
                        % file_path)
                    self.skipped_files.add(file_path)
                    if abort_on_invalid_input:
                        raise ValueError("Invalid input detected, aborting.")
                else:
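                    # group matched files by the index of the asset mapping that matched them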
                    asset_group = file_entry[0]
                    group_list = file_list.get(asset_group, [])
                    group_list.append(file_entry)
                    file_list[asset_group] = group_list

        # make sure that file entries in both self.file_list and self.file_status are ordered by the declared order of
        # the asset_mapping for the file
        for group in sorted(file_list.keys()):
            self.file_list[group] = file_list[group]
            for file_entry in file_list[group]:
                file_path = file_entry[3]
                logging.info("Including file: [%s]." % file_path)
                status = self.getTransferStateStatus(file_path)
                if status:
                    self.file_status[file_path] = FileUploadState(
                        UploadState.Paused, status)._asdict()
                else:
                    self.file_status[file_path] = FileUploadState(
                        UploadState.Pending, "Pending")._asdict()

    def getAssetMapping(self, file_path):
        """
        :param file_path:
        :return:
        """
        asset_group = -1
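        # mappings are evaluated in declared order; the first mapping whose configured
        # patterns (dir_pattern, ext_pattern, file_pattern, when present) all match the
        # path wins, and its zero-based index is returned as the asset group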
        for asset_type in self.asset_mappings:
            asset_group += 1
            groupdict = dict()
            dir_pattern = asset_type.get('dir_pattern', '')
            ext_pattern = asset_type.get('ext_pattern', '')
            file_pattern = asset_type.get('file_pattern', '')
            path = file_path.replace("\\", "/")
            if dir_pattern:
                match = re.search(dir_pattern, path)
                if not match:
                    logging.debug(
                        "The dir_pattern \"%s\" failed to match the input path [%s]"
                        % (dir_pattern, path))
                    continue
                groupdict.update(match.groupdict())
            if ext_pattern:
                match = re.search(ext_pattern, path, re.IGNORECASE)
                if not match:
                    logging.debug(
                        "The ext_pattern \"%s\" failed to match the input path [%s]"
                        % (ext_pattern, path))
                    continue
                groupdict.update(match.groupdict())
            if file_pattern:
                match = re.search(file_pattern, path)
                if not match:
                    logging.debug(
                        "The file_pattern \"%s\" failed to match the input path [%s]"
                        % (file_pattern, path))
                    continue
                groupdict.update(match.groupdict())

            return asset_group, asset_type, groupdict

        return None, None, None

    def uploadFiles(self, status_callback=None, file_callback=None):
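        # process the scanned files group by group in the order established by scanDirectory,
        # recording a per-file FileUploadState transition for each outcome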
        for group, assets in self.file_list.items():
            for asset_group_num, asset_mapping, groupdict, file_path in assets:
                if self.cancelled:
                    self.file_status[file_path] = FileUploadState(
                        UploadState.Cancelled, "Cancelled by user")._asdict()
                    continue
                try:
                    self.file_status[file_path] = FileUploadState(
                        UploadState.Running, "In-progress")._asdict()
                    if status_callback:
                        status_callback()
                    self.uploadFile(file_path, asset_mapping, groupdict,
                                    file_callback)
                    self.file_status[file_path] = FileUploadState(
                        UploadState.Success, "Complete")._asdict()
                except HatracJobPaused:
                    status = self.getTransferStateStatus(file_path)
                    if status:
                        self.file_status[file_path] = FileUploadState(
                            UploadState.Paused,
                            "Paused: %s" % status)._asdict()
                    continue
                except HatracJobTimeout:
                    status = self.getTransferStateStatus(file_path)
                    if status:
                        self.file_status[file_path] = FileUploadState(
                            UploadState.Timeout, "Timeout")._asdict()
                    continue
                except HatracJobAborted:
                    self.file_status[file_path] = FileUploadState(
                        UploadState.Aborted, "Aborted by user")._asdict()
                except Exception:
                    (etype, value, traceback) = sys.exc_info()
                    self.file_status[file_path] = FileUploadState(
                        UploadState.Failed, format_exception(value))._asdict()
                self.delTransferState(file_path)
                if status_callback:
                    status_callback()

        failed_uploads = dict()
        for key, value in self.file_status.items():
            if (value["State"]
                    == UploadState.Failed) or (value["State"]
                                               == UploadState.Timeout):
                failed_uploads[key] = value["Status"]

        if self.skipped_files:
            logging.warning(
                "The following file(s) were skipped because they did not satisfy the matching criteria "
                "of the configuration:\n\n%s\n" %
                '\n'.join(sorted(self.skipped_files)))

        if failed_uploads:
            logging.warning(
                "The following file(s) failed to upload due to errors:\n\n%s\n"
                % '\n'.join([
                    "%s -- %s" % (key, failed_uploads[key])
                    for key in sorted(failed_uploads.keys())
                ]))
            raise RuntimeError(
                "One or more file(s) failed to upload due to errors.")

    def uploadFile(self,
                   file_path,
                   asset_mapping,
                   match_groupdict,
                   callback=None):
        """
        Primary API subclass function.
        :param file_path:
        :param asset_mapping:
        :param match_groupdict:
        :param callback:
        :return:
        """
        logging.info("Processing file: [%s]" % file_path)

        if asset_mapping.get("asset_type", "file") == "table":
            self._uploadTable(file_path, asset_mapping, match_groupdict)
        else:
            self._uploadAsset(file_path, asset_mapping, match_groupdict,
                              callback)

    def _uploadAsset(self,
                     file_path,
                     asset_mapping,
                     match_groupdict,
                     callback=None):

        # 1. Populate metadata by querying the catalog
        self._queryFileMetadata(file_path, asset_mapping, match_groupdict)

        # 2. If "create_record_before_upload" specified in asset_mapping, check for an existing record, creating a new
        #    one if necessary. Otherwise delay this logic until after the file upload.
        record = None
        if stob(asset_mapping.get("create_record_before_upload", False)):
            record = self._getFileRecord(asset_mapping)

        # 3. Perform the Hatrac upload
        self._getFileHatracMetadata(asset_mapping)
        hatrac_options = asset_mapping.get("hatrac_options", {})
        versioned_uri = \
            self._hatracUpload(self.metadata["URI"],
                               file_path,
                               md5=self.metadata.get("md5_base64"),
                               sha256=self.metadata.get("sha256_base64"),
                               content_type=self.guessContentType(file_path),
                               content_disposition=self.metadata.get("content-disposition"),
                               chunked=True,
                               create_parents=stob(hatrac_options.get("create_parents", True)),
                               allow_versioning=stob(hatrac_options.get("allow_versioning", True)),
                               callback=callback)
        logging.debug("Hatrac upload successful. Result object URI: %s" %
                      versioned_uri)
        if stob(hatrac_options.get("versioned_uris", True)):
            self.metadata["URI"] = versioned_uri
        else:
            self.metadata["URI"] = versioned_uri.rsplit(":")[0]
        self.metadata["URI_urlencoded"] = urlquote(self.metadata["URI"])

        # 4. Check for an existing record and create a new one if necessary
        if not record:
            record = self._getFileRecord(asset_mapping)

        # 5. Update the existing record, if necessary
        column_map = asset_mapping.get("column_map", {})
        updated_record = self.interpolateDict(self.metadata, column_map)
        if updated_record != record:
            logging.info("Updating catalog for file [%s]" %
                         self.getFileDisplayName(file_path))
            self._catalogRecordUpdate(self.metadata['target_table'], record,
                                      updated_record)

    def _uploadTable(self,
                     file_path,
                     asset_mapping,
                     match_groupdict,
                     callback=None):
        if self.cancelled:
            return None

        self._initFileMetadata(file_path, asset_mapping, match_groupdict)
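        # bulk-insert the file contents into the target table via the catalog /entity/ API,
        # deriving the request content-type from the file extension (csv or json)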
        try:
            default_columns = asset_mapping.get("default_columns")
            if not default_columns:
                default_columns = self.catalog.getDefaultColumns(
                    {}, self.metadata['target_table'])
            default_param = (
                '?defaults=%s' %
                ','.join(default_columns)) if len(default_columns) > 0 else ''
            file_ext = self.metadata['file_ext']
            if file_ext == 'csv':
                headers = {'content-type': 'text/csv'}
            elif file_ext == 'json':
                headers = {'content-type': 'application/json'}
            else:
                raise CatalogCreateError(
                    "Unsupported file type for catalog bulk upload: %s" %
                    file_ext)
            with open(file_path, "rb") as fp:
                result = self.catalog.post(
                    '/entity/%s%s' %
                    (self.metadata['target_table'], default_param),
                    fp,
                    headers=headers)
                return result
        except Exception:
            (etype, value, traceback) = sys.exc_info()
            raise CatalogCreateError(format_exception(value))

    def _getFileRecord(self, asset_mapping):
        """
        Helper function that queries the catalog to get a record linked to the asset, or creates it if it doesn't exist.
        :return: the file record
        """
        column_map = asset_mapping.get("column_map", {})
        rqt = asset_mapping['record_query_template']
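        # interpolate the record query template with the current metadata context
        # to locate an existing catalog record for this asset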
        try:
            path = rqt.format(**self.metadata)
        except KeyError as e:
            raise ConfigurationError(
                "Record query template substitution error: %s" %
                format_exception(e))
        result = self.catalog.get(path).json()
        if result:
            self._updateFileMetadata(result[0])
            return self.pruneDict(result[0], column_map)
        else:
            row = self.interpolateDict(self.metadata, column_map)
            result = self._catalogRecordCreate(self.metadata['target_table'],
                                               row)
            if result:
                self._updateFileMetadata(result[0])
            return self.interpolateDict(self.metadata,
                                        column_map,
                                        allowNone=True)

    def _urlEncodeMetadata(self, safe_overrides=None):
        urlencoded = dict()
        if not safe_overrides:
            safe_overrides = dict()
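        # for each metadata key, add a companion "<key>_urlencoded" value that is safe
        # to substitute into URL templates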
        for k, v in self.metadata.items():
            if k.endswith("_urlencoded"):
                continue
            urlencoded[k + "_urlencoded"] = urlquote(str(v),
                                                     safe_overrides.get(k, ""))
        self._updateFileMetadata(urlencoded)

    def _initFileMetadata(self, file_path, asset_mapping, match_groupdict):
        self.metadata.clear()
        self._updateFileMetadata(match_groupdict)

        self.metadata['target_table'] = self.getCatalogTable(
            asset_mapping, match_groupdict)
        self.metadata["file_name"] = self.getFileDisplayName(file_path)
        self.metadata["file_size"] = self.getFileSize(file_path)

        self._urlEncodeMetadata(
            asset_mapping.get("url_encoding_safe_overrides"))

    def _updateFileMetadata(self, src, strict=False):
        if not isinstance(src, dict):
            raise ValueError(
                "Invalid input parameter type(s): (src = %s), expected (dict)"
                % type(src).__name__)
        if strict:
            # iterate over a copy of the keys since reserved keys are deleted from src in the loop
            for k in list(src.keys()):
                if k in UploadMetadataReservedKeyNames:
                    logging.warning(
                        "Context metadata update specified reserved key name [%s], "
                        "ignoring value: %s " % (k, src[k]))
                    del src[k]
        self.metadata.update(src)

    def _queryFileMetadata(self, file_path, asset_mapping, match_groupdict):
        """
        Helper function that queries the catalog to get required metadata for a given file/asset
        """
        file_name = self.getFileDisplayName(file_path)
        logging.info("Computing metadata for file: [%s]." % file_name)
        self._initFileMetadata(file_path, asset_mapping, match_groupdict)

        logging.info("Computing checksums for file: [%s]. Please wait..." %
                     file_name)
        hashes = self.getFileHashes(
            file_path, asset_mapping.get('checksum_types', ['md5', 'sha256']))
        for alg, checksum in hashes.items():
            alg = alg.lower()
            self.metadata[alg] = checksum[0]
            self.metadata[alg + "_base64"] = checksum[1]

        for uri in asset_mapping.get("metadata_query_templates", []):
            try:
                path = uri.format(**self.metadata)
            except KeyError as e:
                raise RuntimeError(
                    "Metadata query template substitution error: %s" %
                    format_exception(e))
            result = self.catalog.get(path).json()
            if result:
                self._updateFileMetadata(result[0], True)
                self._urlEncodeMetadata(
                    asset_mapping.get("url_encoding_safe_overrides"))
            else:
                raise RuntimeError(
                    "Metadata query did not return any results: %s" % path)

        self._getFileExtensionMetadata(self.metadata.get("file_ext"))

        for k, v in asset_mapping.get("column_value_templates", {}).items():
            try:
                self.metadata[k] = v.format(**self.metadata)
            except KeyError as e:
                logging.warning(
                    "Column value template substitution error: %s" %
                    format_exception(e))
                continue
        self._urlEncodeMetadata(
            asset_mapping.get("url_encoding_safe_overrides"))

    def _getFileExtensionMetadata(self, ext):
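        # merge any per-extension metadata configured under "file_ext_mappings" into the current context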
        ext_map = self.config.get("file_ext_mappings", {})
        entry = ext_map.get(ext)
        if entry:
            self._updateFileMetadata(entry)

    def _getFileHatracMetadata(self, asset_mapping):
        try:
            hatrac_templates = asset_mapping["hatrac_templates"]
            # URI is required
            self.metadata["URI"] = hatrac_templates["hatrac_uri"].format(
                **self.metadata)
            # overridden content-disposition is optional
            content_disposition = hatrac_templates.get("content-disposition")
            self.metadata["content-disposition"] = \
                None if not content_disposition else content_disposition.format(**self.metadata)
            self._urlEncodeMetadata(
                asset_mapping.get("url_encoding_safe_overrides"))
        except KeyError as e:
            raise ConfigurationError("Hatrac template substitution error: %s" %
                                     format_exception(e))

    def _hatracUpload(self,
                      uri,
                      file_path,
                      md5=None,
                      sha256=None,
                      content_type=None,
                      content_disposition=None,
                      chunked=True,
                      create_parents=True,
                      allow_versioning=True,
                      callback=None):

        # check if there is already an in-progress transfer for this file,
        # and if so, that the local file has not been modified since the original upload job was created
        can_resume = False
        transfer_state = self.getTransferState(file_path)
        if transfer_state:
            content_md5 = transfer_state.get("content-md5")
            content_sha256 = transfer_state.get("content-sha256")
            if content_md5 or content_sha256:
                if (md5 == content_md5) or (sha256 == content_sha256):
                    can_resume = True

        if transfer_state and can_resume:
            logging.info(
                "Resuming upload (%s) of file: [%s] to host %s. Please wait..."
                % (self.getTransferStateStatus(file_path), file_path,
                   transfer_state.get("host")))
            path = transfer_state["target"]
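            # the upload job id is the last path component of the saved job URL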
            job_id = transfer_state['url'].rsplit("/", 1)[1]
            if not (transfer_state["total"] == transfer_state["completed"]):
                self.store.put_obj_chunked(
                    path,
                    file_path,
                    job_id,
                    callback=callback,
                    start_chunk=transfer_state["completed"])
            return self.store.finalize_upload_job(path, job_id)
        else:
            logging.info("Uploading file: [%s] to host %s. Please wait..." %
                         (self.getFileDisplayName(file_path), self.server_url))
            return self.store.put_loc(uri,
                                      file_path,
                                      md5=md5,
                                      sha256=sha256,
                                      content_type=content_type,
                                      content_disposition=content_disposition,
                                      chunked=chunked,
                                      create_parents=create_parents,
                                      allow_versioning=allow_versioning,
                                      callback=callback)

    def _catalogRecordCreate(self, catalog_table, row, default_columns=None):
        """

        :param catalog_table:
        :param row:
        :param default_columns:
        :return:
        """
        if self.cancelled:
            return None

        try:
            missing = self.catalog.validateRowColumns(row, catalog_table)
            if missing:
                raise CatalogCreateError(
                    "Unable to update catalog entry because one or more specified columns do not exist in the "
                    "target table: [%s]" % ','.join(missing))
            if not default_columns:
                default_columns = self.catalog.getDefaultColumns(
                    row, catalog_table)
            default_param = (
                '?defaults=%s' %
                ','.join(default_columns)) if len(default_columns) > 0 else ''
            # for default in default_columns:
            #    row[default] = None
            create_uri = '/entity/%s%s' % (catalog_table, default_param)
            logging.debug(
                "Attempting catalog record create [%s] with data: %s" %
                (create_uri, json.dumps(row)))
            return self.catalog.post(create_uri, json=[row]).json()
        except Exception:
            (etype, value, traceback) = sys.exc_info()
            raise CatalogCreateError(format_exception(value))

    def _catalogRecordUpdate(self, catalog_table, old_row, new_row):
        """

        :param catalog_table:
        :param new_row:
        :param old_row:
        :return:
        """
        if self.cancelled:
            return None

        try:
            keys = sorted(list(new_row.keys()))
            old_keys = sorted(list(old_row.keys()))
            if keys != old_keys:
                raise RuntimeError(
                    "Cannot update catalog - "
                    "new row column list and old row column list do not match: New: %s != Old: %s"
                    % (keys, old_keys))
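            # build a single attributegroup PUT payload: "o<i>" aliases carry the existing
            # (correlation) values used to match the row, "n<i>" aliases carry the new values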
            combined_row = {
                'o%d' % i: old_row[keys[i]]
                for i in range(len(keys))
            }
            combined_row.update(
                {'n%d' % i: new_row[keys[i]]
                 for i in range(len(keys))})
            update_uri = '/attributegroup/%s/%s;%s' % (catalog_table, ','.join(
                ["o%d:=%s" % (i, urlquote(keys[i]))
                 for i in range(len(keys))]), ','.join([
                     "n%d:=%s" % (i, urlquote(keys[i]))
                     for i in range(len(keys))
                 ]))
            logging.debug(
                "Attempting catalog record update [%s] with data: %s" %
                (update_uri, json.dumps(combined_row)))
            return self.catalog.put(update_uri, json=[combined_row]).json()
        except Exception:
            (etype, value, traceback) = sys.exc_info()
            raise CatalogUpdateError(format_exception(value))

    def defaultFileCallback(self, **kwargs):
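        # default per-chunk progress callback: checkpoints progress into the transfer state
        # and returns -1 to signal the store to abort when the upload has been cancelled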
        completed = kwargs.get("completed")
        total = kwargs.get("total")
        file_path = kwargs.get("file_path")
        file_name = os.path.basename(file_path) if file_path else ""
        job_info = kwargs.get("job_info", {})
        job_info.update()
        if completed and total:
            file_name = " [%s]" % file_name
            job_info.update({
                "completed": completed,
                "total": total,
                "host": kwargs.get("host")
            })
            status = "Uploading file%s: %d%% complete" % (
                file_name,
                round(((float(completed) / float(total)) % 100) * 100))
            self.setTransferState(file_path, job_info)
        else:
            summary = kwargs.get("summary", "")
            file_name = "Uploaded file: [%s] " % file_name
            status = file_name  # + summary
        if status:
            # logging.debug(status)
            pass
        if self.cancelled:
            return -1

        return True

    def loadTransferState(self):
        transfer_state_file_path = self.getDeployedTransferStateFilePath()
        transfer_state_dir = os.path.dirname(transfer_state_file_path)
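        # create the transfer state file if missing, then hold it open read/write so that
        # checkpoint data can be rewritten in place as upload chunks complete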
        try:
            if not os.path.isdir(transfer_state_dir):
                try:
                    os.makedirs(transfer_state_dir)
                except OSError as error:
                    if error.errno != errno.EEXIST:
                        raise

            if not os.path.isfile(transfer_state_file_path):
                with open(transfer_state_file_path, "w") as tsfp:
                    json.dump(self.transfer_state, tsfp)

            self.transfer_state_fp = \
                open(transfer_state_file_path, 'r+')
            self.transfer_state = json.load(self.transfer_state_fp,
                                            object_pairs_hook=OrderedDict)
        except Exception as e:
            logging.warning(
                "Unable to read transfer state file, transfer checkpointing will not be available. "
                "Error: %s" % format_exception(e))

    def getTransferState(self, file_path):
        return self.transfer_state.get(file_path)

    def setTransferState(self, file_path, transfer_state):
        self.transfer_state[file_path] = transfer_state
        self.writeTransferState()

    def delTransferState(self, file_path):
        transfer_state = self.getTransferState(file_path)
        if transfer_state:
            del self.transfer_state[file_path]
        self.writeTransferState()

    def writeTransferState(self):
        if not self.transfer_state_fp:
            return
        try:
            self.transfer_state_fp.seek(0, 0)
            self.transfer_state_fp.truncate()
            json.dump(self.transfer_state, self.transfer_state_fp, indent=2)
            self.transfer_state_fp.flush()
        except Exception as e:
            logging.warning("Unable to write transfer state file: %s" %
                            format_exception(e))

    def cleanupTransferState(self):
        if self.transfer_state_fp and not self.transfer_state_fp.closed:
            try:
                self.transfer_state_fp.flush()
                self.transfer_state_fp.close()
            except Exception as e:
                logging.warning(
                    "Unable to flush/close transfer state file: %s" %
                    format_exception(e))

    def getTransferStateStatus(self, file_path):
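        # report checkpointed progress for a resumable upload as a percentage of the
        # saved "completed" vs. "total" counters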
        transfer_state = self.getTransferState(file_path)
        if transfer_state:
            return "%d%% complete" % (round(
                ((float(transfer_state["completed"]) /
                  float(transfer_state["total"])) % 100) * 100))
        return None