def test_escape_dn(self):
        """escape_dn unquotes DN values and backslash-escapes specials.

        Regression test for http://www.dataflake.org/tracker/issue_00623:
        double-quoted values containing ',', '+', '<', '>' and ';' must
        come back unquoted with those characters escaped, and None must
        pass through unchanged.
        """
        dn = 'cn="Joe Miller, Sr.", ou="odds+sods <1>", dc="host;new"'
        dn_clean = 'cn=Joe Miller\\, Sr.,ou=odds\\+sods \\<1\\>,dc=host\\;new'
        # assertEquals is a deprecated alias (removed in Python 3.12).
        self.assertEqual(escape_dn(dn), dn_clean)

        self.assertIsNone(escape_dn(None))
    def test_escape_dn(self):
        """Check escape_dn turns a quoted DN into an escaped bytes DN.

        Regression test for http://www.dataflake.org/tracker/issue_00623.
        A None argument is expected to come back as None.
        """
        quoted_dn = 'cn="Joe Miller, Sr.", ou="odds+sods <1>", dc="host;new"'
        expected = b'cn=Joe Miller\\, Sr.,ou=odds\\+sods \\<1\\>,dc=host\\;new'
        self.assertEqual(escape_dn(quoted_dn), expected)
        self.assertEqual(escape_dn(None), None)
Example #3
0
    def connect(self, bind_dn=None, bind_pwd=None):
        """ initialize an ldap server connection

        This method returns an instance of the underlying `pyldap`
        connection class. It does not need to be called explicitly, all
        other operations call it implicitly.

        If ``bind_dn`` is None the configured ``self.bind_dn`` and
        ``self.bind_pwd`` are used, otherwise the given credentials are.
        A connection returned by ``self._getConnection()`` is reused;
        otherwise each configured server is tried in turn and the first
        one that connects is stored in ``connection_cache``. The
        connection is only re-bound when its last bind used different
        credentials.

        Raises RuntimeError if no servers are configured. If every server
        fails, the last connection error is re-raised.
        """
        if not self.servers:
            raise RuntimeError('No servers defined')

        if bind_dn is None:
            bind_dn = self.bind_dn
            bind_pwd = self.bind_pwd
        bind_dn = escape_dn(self._encode_incoming(bind_dn),
                            self.ldap_encoding)
        bind_pwd = self._encode_incoming(bind_pwd)

        conn = self._getConnection()
        if conn is None:
            # Initialise so the failure report below never hits an
            # unbound name, even if no attempt raised.
            exc = None
            for server in self.servers.values():
                try:
                    conn = self._connect(server['url'],
                                         conn_timeout=server['conn_timeout'],
                                         op_timeout=server['op_timeout'])
                    if server.get('start_tls', None):
                        conn.start_tls_s()
                    break
                except (ldap.SERVER_DOWN, ldap.TIMEOUT, ldap.LOCAL_ERROR) as e:
                    conn = None
                    # Python 3 unbinds ``e`` when the handler exits, so keep
                    # our own reference for reporting after the loop.
                    exc = e

            if conn is None:
                msg = 'Failure connecting, last attempt: %s (%s)' % (
                    server['url'], str(exc) or 'no exception')
                # Outside the except block sys.exc_info() is already cleared
                # under Python 3, so pass the saved exception explicitly.
                self.logger().critical(msg, exc_info=exc)

                if exc is not None:
                    raise exc
                raise RuntimeError(msg)

            connection_cache.set(self.hash, conn)

        last_bind = getattr(conn, '_last_bind', None)
        if not last_bind or \
           last_bind[1][0] != bind_dn or \
           last_bind[1][1] != bind_pwd:
            conn.simple_bind_s(bind_dn, bind_pwd)

        return conn
    def search(self,
               base,
               scope=ldap.SCOPE_SUBTREE,
               fltr='(objectClass=*)',
               attrs=None,
               convert_filter=True,
               bind_dn=None,
               bind_pwd=None,
               raw=False):
        """ Search for entries in the database

        ``base`` is encoded and DN-escaped before searching. ``fltr`` is
        only encoded when ``convert_filter`` is true. On a partial-results
        error the available result is read with ``connection.result(all=0)``.
        On a referral the search is retried once against the connection
        returned by ``_handle_referral``.
        """
        # NOTE(review): in this copy of the method ``result`` is never
        # filled in and nothing is returned after the search; the
        # result-processing code appears to be missing -- confirm.
        result = {'size': 0, 'results': [], 'exception': ''}
        if convert_filter:
            fltr = self._encode_incoming(fltr)
        base = escape_dn(self._encode_incoming(base))
        connection = self.connect(bind_dn=bind_dn, bind_pwd=bind_pwd)

        try:
            res = connection.search_s(base, scope, fltr, attrs)
        except ldap.PARTIAL_RESULTS:
            res_type, res = connection.result(all=0)
        except ldap.REFERRAL as e:
            # ``except X, e`` is a SyntaxError on Python 3; ``as`` works on
            # Python 2.6+ as well.
            connection = self._handle_referral(e)

            try:
                res = connection.search_s(base, scope, fltr, attrs)
            except ldap.PARTIAL_RESULTS:
                res_type, res = connection.result(all=0)
Example #5
0
    def insert(self, base, rdn, attrs=None, bind_dn=None, bind_pwd=None):
        """ Insert a new record

        attrs is expected to be a mapping where the value may be a string
        or a sequence of strings.
        Multiple values may be expressed as a single string if the values
        are semicolon-delimited.
        Values can be marked as binary values, meaning they are not encoded
        as UTF-8, by appending ';binary' to the key.
        Attributes whose value is an empty string (text or bytes) are not
        sent to the server.
        """
        self._complainIfReadOnly()
        base = escape_dn(self._encode_incoming(base), self.ldap_encoding)
        rdn = escape_dn(self._encode_incoming(rdn), self.ldap_encoding)

        dn = rdn + b',' + base
        attribute_list = []

        for attr_key, values in (attrs or {}).items():
            # Keys may be text or bytes (see the binary_type check below);
            # match the suffix in the same type so bytes keys don't raise
            # TypeError under Python 3.
            if isinstance(attr_key, six.binary_type):
                binary_suffix = b';binary'
            else:
                binary_suffix = ';binary'
            is_binary = attr_key.endswith(binary_suffix)
            if is_binary:
                attr_key = attr_key[:-len(binary_suffix)]

            if not isinstance(attr_key, six.binary_type):
                attr_key = self._encode_incoming(attr_key)

            if not is_binary:
                if isinstance(values, six.string_types):
                    values = [x.strip() for x in values.split(';')]
                elif isinstance(values, six.binary_type):
                    values = [x.strip() for x in values.split(b';')]

            # An empty value splits to [''] for text or [b''] for bytes;
            # both mean "no value" and are skipped.
            if values not in ([''], [b'']):
                if not is_binary:
                    values = [self._encode_incoming(x) for x in values]
                attribute_list.append((attr_key, values))

        try:
            connection = self.connect(bind_dn=bind_dn, bind_pwd=bind_pwd)
            connection.add_s(dn, attribute_list)
        except ldap.REFERRAL as e:
            connection = self._handle_referral(e)
            connection.add_s(dn, attribute_list)
    def connect(self, bind_dn=None, bind_pwd=None):
        """ initialize an ldap server connection

        This method returns an instance of the underlying `python-ldap`
        connection class. It does not need to be called explicitly, all
        other operations call it implicitly.

        If ``bind_dn`` is None the configured ``self.bind_dn`` and
        ``self.bind_pwd`` are used. If no cached connection exists, each
        configured server is tried in turn. If all of them fail, the last
        connection error is re-raised.
        """
        # NOTE(review): this copy of the method ends after caching the
        # connection. It neither binds with the prepared credentials nor
        # returns ``conn``, although callers use the return value. The
        # tail appears to be missing -- confirm.
        if not self.servers:
            raise RuntimeError('No servers defined')

        if bind_dn is None:
            bind_dn = self.bind_dn
            bind_pwd = self.bind_pwd
        bind_dn = escape_dn(self._encode_incoming(bind_dn))
        bind_pwd = self._encode_incoming(bind_pwd)

        conn = self._getConnection()
        if conn is None:
            # ``e`` is unbound after the handler under Python 3, so keep the
            # last failure in our own variable.
            exc = None
            for server in self.servers.values():
                try:
                    conn = self._connect(server['url'],
                                         conn_timeout=server['conn_timeout'],
                                         op_timeout=server['op_timeout'])
                    if server.get('start_tls', None):
                        conn.start_tls_s()
                    break
                except (ldap.SERVER_DOWN, ldap.TIMEOUT, ldap.LOCAL_ERROR) as e:
                    conn = None
                    exc = e

            if conn is None:
                msg = 'Failure connecting, last attempt: %s (%s)' % (
                    server['url'], str(exc or 'no exception'))
                # Pass the saved exception: sys.exc_info() is empty here
                # under Python 3.
                self.logger().critical(msg, exc_info=exc)

                if exc is not None:
                    raise exc
                raise RuntimeError(msg)

            connection_cache.set(self.hash, conn)
Example #7
0
    def delete(self, dn, bind_dn=None, bind_pwd=None):
        """ Delete the record identified by ``dn``

        The DN is encoded and escaped with ``self.ldap_encoding`` first.
        If the server answers with a referral, the delete is repeated
        once on the connection produced by ``_handle_referral``.
        """
        self._complainIfReadOnly()

        escaped_dn = escape_dn(self._encode_incoming(dn), self.ldap_encoding)

        try:
            conn = self.connect(bind_dn=bind_dn, bind_pwd=bind_pwd)
            conn.delete_s(escaped_dn)
        except ldap.REFERRAL as referral:
            referred_conn = self._handle_referral(referral)
            referred_conn.delete_s(escaped_dn)
    def delete(self, dn, bind_dn=None, bind_pwd=None):
        """ Delete a record

        ``dn`` is encoded and DN-escaped first. On a referral the delete
        is retried once on the connection returned by ``_handle_referral``.
        """
        self._complainIfReadOnly()

        dn = escape_dn(self._encode_incoming(dn))

        try:
            connection = self.connect(bind_dn=bind_dn, bind_pwd=bind_pwd)
            connection.delete_s(dn)
        except ldap.REFERRAL as e:
            # ``except X, e`` is a SyntaxError on Python 3; ``as`` also
            # works on Python 2.6+.
            connection = self._handle_referral(e)
            connection.delete_s(dn)
    def insert(self, base, rdn, attrs=None, bind_dn=None, bind_pwd=None):
        """ Insert a new record

        attrs is expected to be a mapping where the value may be a string
        or a sequence of strings.
        Multiple values may be expressed as a single string if the values
        are semicolon-delimited.
        Values can be marked as binary values, meaning they are not encoded
        as UTF-8, by appending ';binary' to the key.
        Attributes whose value is the empty string are not sent.
        """
        self._complainIfReadOnly()
        base = escape_dn(self._encode_incoming(base))
        rdn = escape_dn(self._encode_incoming(rdn))

        dn = rdn + ',' + base
        attribute_list = []
        attrs = attrs or {}

        for attr_key, values in attrs.items():
            if attr_key.endswith(';binary'):
                is_binary = True
                attr_key = attr_key[:-7]
            else:
                is_binary = False

            if isinstance(values, basestring) and not is_binary:
                values = [x.strip() for x in values.split(';')]

            if values != ['']:
                if not is_binary:
                    values = [self._encode_incoming(x) for x in values]
                attribute_list.append((attr_key, values))

        try:
            connection = self.connect(bind_dn=bind_dn, bind_pwd=bind_pwd)
            connection.add_s(dn, attribute_list)
        except ldap.REFERRAL as e:
            # ``as`` form: valid on Python 2.6+ and required on Python 3.
            connection = self._handle_referral(e)
            connection.add_s(dn, attribute_list)
Example #10
0
    def modify(self,
               dn,
               mod_type=None,
               attrs=None,
               bind_dn=None,
               bind_pwd=None):
        """ Modify a record

        ``attrs`` maps attribute names to new values, given either as a
        semicolon-delimited string or as a sequence. Names ending in
        ';binary' have the suffix stripped and their values are passed
        through unencoded.

        When ``mod_type`` is None, the operation for each attribute is
        derived from the current record:
        - MOD_ADD for attributes the record does not have yet;
        - MOD_REPLACE for attributes whose value changed;
        - MOD_DELETE for attributes set to an empty value.

        Otherwise ``mod_type`` is used for every attribute, with two
        exceptions: empty values are skipped for MOD_ADD/MOD_DELETE, and
        MOD_DELETE is skipped when it names values the record lacks.

        If a new value for the RDN attribute differs from the current one,
        the entry is renamed with ``modrdn_s`` before the remaining changes
        are applied. The caller's ``attrs`` mapping is not modified.
        """
        self._complainIfReadOnly()

        unescaped_dn = self._encode_incoming(dn)
        dn = escape_dn(unescaped_dn, self.ldap_encoding)
        res = self.search(base=unescaped_dn,
                          scope=ldap.SCOPE_BASE,
                          bind_dn=bind_dn,
                          bind_pwd=bind_pwd,
                          raw=True)
        # Work on a private copy: encoded values are stored back under
        # encoded keys below, and that must not leak into the caller's dict.
        attrs = dict(attrs.items()) if attrs else {}
        cur_rec = res['results'][0]
        mod_list = []

        for key, values in list(attrs.items()):

            # Keys may be text or bytes; match the suffix in the same type
            # so bytes keys don't raise TypeError under Python 3.
            if isinstance(key, six.binary_type):
                binary_suffix = b';binary'
            else:
                binary_suffix = ';binary'
            is_binary = key.endswith(binary_suffix)
            if is_binary:
                key = key[:-len(binary_suffix)]

            if not isinstance(key, six.binary_type):
                key = self._encode_incoming(key)

            if not is_binary:
                if isinstance(values, six.string_types):
                    values = [
                        self._encode_incoming(x) for x in values.split(';')
                    ]
                else:
                    values = [self._encode_incoming(x) for x in values]

            if mod_type is None:
                if key not in cur_rec and values != [b'']:
                    mod_list.append((ldap.MOD_ADD, key, values))
                elif cur_rec.get(key, [b'']) != values and \
                        values not in ([b''], []):
                    mod_list.append((ldap.MOD_REPLACE, key, values))
                elif key in cur_rec and values in ([b''], []):
                    mod_list.append((ldap.MOD_DELETE, key, None))
            elif mod_type in (ldap.MOD_ADD, ldap.MOD_DELETE) and \
                    values == [b'']:
                continue
            elif mod_type == ldap.MOD_DELETE and \
                    set(values).difference(set(cur_rec.get(key, []))):
                continue
            else:
                mod_list.append((mod_type, key, values))

            # Keep the encoded form so the RDN lookup below can find it by
            # its encoded attribute name.
            attrs[key] = values

        try:
            connection = self.connect(bind_dn=bind_dn, bind_pwd=bind_pwd)

            # Encode every RDN component. Each RDN's attribute-value pairs
            # stay grouped, so multi-valued RDNs (e.g. cn=a+uid=b) are not
            # split into separate RDNs.
            clean_dn_parts = []
            for dn_part in str2dn(dn):
                clean_rdn = []
                for (attr_name, attr_val, flag) in dn_part:
                    if isinstance(attr_name, six.text_type):
                        attr_name = self._encode_incoming(attr_name)
                    if isinstance(attr_val, six.text_type):
                        attr_val = self._encode_incoming(attr_val)
                    clean_rdn.append((attr_name, attr_val, flag))
                clean_dn_parts.append(clean_rdn)

            rdn_attr = clean_dn_parts[0][0][0]
            raw_rdn = attrs.get(rdn_attr, '')
            if isinstance(raw_rdn, six.string_types):
                raw_rdn = [raw_rdn]
            # An RDN attribute cleared to [] supplies no new RDN value.
            new_rdn = raw_rdn[0] if raw_rdn else ''

            if new_rdn:
                rdn_value = self._encode_incoming(new_rdn)
                if rdn_value != cur_rec.get(rdn_attr)[0]:
                    clean_dn_parts[0] = [(rdn_attr, rdn_value, 1)]
                    raw_utf8_rdn = rdn_attr + b'=' + rdn_value
                    new_rdn = escape_dn(raw_utf8_rdn, self.ldap_encoding)
                    connection.modrdn_s(dn, new_rdn)
                    dn = dn2str(clean_dn_parts)

            if mod_list:
                connection.modify_s(dn, mod_list)
            else:
                debug_msg = 'Nothing to modify: %s' % dn
                self.logger().debug(debug_msg)

        except ldap.REFERRAL as e:
            connection = self._handle_referral(e)
            connection.modify_s(dn, mod_list)
Example #11
0
    def search(self,
               base,
               scope=ldap.SCOPE_SUBTREE,
               fltr='(objectClass=*)',
               attrs=None,
               convert_filter=True,
               bind_dn=None,
               bind_pwd=None,
               raw=False):
        """ Search for entries in the database

        Returns a dict with keys 'size' (number of records), 'results'
        (list of record dicts, each carrying its DN under 'dn') and
        'exception' (left as the empty string here).

        With ``raw`` true, the record dicts are returned as delivered by
        the server apart from the added 'dn'. Otherwise every attribute
        not listed in BINARY_ATTRIBUTES is passed through
        ``_encode_outgoing``.
        """
        result = {'size': 0, 'results': [], 'exception': ''}
        if convert_filter:
            fltr = self._encode_incoming(fltr)
        base = escape_dn(self._encode_incoming(base), self.ldap_encoding)
        connection = self.connect(bind_dn=bind_dn, bind_pwd=bind_pwd)

        try:
            res = connection.search_s(base, scope, fltr, attrs)
        except ldap.PARTIAL_RESULTS:
            res_type, res = connection.result(all=0)
        except ldap.REFERRAL as e:
            connection = self._handle_referral(e)

            try:
                res = connection.search_s(base, scope, fltr, attrs)
            except ldap.PARTIAL_RESULTS:
                res_type, res = connection.result(all=0)

        for rec_dn, rec_dict in res:
            # When used against Active Directory, "rec_dict" may not be
            # a dictionary in some cases (instead, it can be a list).
            # An example of a useless "res" entry that can be ignored
            # from AD is
            # (None, ['ldap://ForestDnsZones.PORTAL.LOCAL/DC=ForestDnsZones,\
            # DC=PORTAL,DC=LOCAL'])
            # This appears to be some sort of internal referral, but
            # we can't handle it, so we need to skip over it.
            try:
                items = list(rec_dict.items())
            except AttributeError:
                # 'items' not found on rec_dict
                continue

            if raw:
                rec_dict['dn'] = rec_dn
            else:
                for key, value in items:
                    if key == b'dn':
                        del rec_dict[key]
                    elif key.lower() not in BINARY_ATTRIBUTES:
                        if isinstance(value, (list, tuple)):
                            # Build a new list: assigning into a tuple by
                            # index would raise TypeError.
                            rec_dict[key] = [self._encode_outgoing(item)
                                             for item in value]
                        else:
                            rec_dict[key] = self._encode_outgoing(value)

                rec_dict['dn'] = self._encode_outgoing(rec_dn)

            result['results'].append(rec_dict)
            result['size'] += 1

        return result
    def modify(self, dn, mod_type=None, attrs=None, bind_dn=None, bind_pwd=None):
        """ Modify a record

        ``attrs`` maps attribute names to new values, given as a
        semicolon-delimited string or a sequence. Names ending in
        ';binary' have the suffix stripped and their values are passed
        through unencoded.

        When ``mod_type`` is None, MOD_ADD, MOD_REPLACE or MOD_DELETE is
        chosen per attribute by comparing against the current record.
        Otherwise ``mod_type`` is applied to each attribute.

        If a new value for the RDN attribute differs from the current one,
        the entry is renamed with ``modrdn_s`` first.
        """
        self._complainIfReadOnly()

        unescaped_dn = self._encode_incoming(dn)
        dn = escape_dn(unescaped_dn)
        res = self.search(base=unescaped_dn,
                          scope=ldap.SCOPE_BASE,
                          bind_dn=bind_dn,
                          bind_pwd=bind_pwd,
                          raw=True)
        attrs = attrs or {}
        cur_rec = res['results'][0]
        mod_list = []

        for key, values in attrs.items():

            if key.endswith(';binary'):
                key = key[:-7]
            elif isinstance(values, basestring):
                values = [self._encode_incoming(x) for x in values.split(';')]
            else:
                values = [self._encode_incoming(x) for x in values]

            # ``key in d`` replaces dict.has_key, which Python 3 removed.
            if mod_type is None:
                if key not in cur_rec and values != ['']:
                    mod_list.append((ldap.MOD_ADD, key, values))
                elif cur_rec.get(key, ['']) != values and values not in ([''], []):
                    mod_list.append((ldap.MOD_REPLACE, key, values))
                elif key in cur_rec and values in ([''], []):
                    mod_list.append((ldap.MOD_DELETE, key, None))
            elif mod_type in (ldap.MOD_ADD, ldap.MOD_DELETE) and values == ['']:
                continue
            elif (mod_type == ldap.MOD_DELETE and
                  set(values).difference(set(cur_rec.get(key, [])))):
                continue
            else:
                mod_list.append((mod_type, key, values))

        try:
            connection = self.connect(bind_dn=bind_dn, bind_pwd=bind_pwd)

            dn_parts = str2dn(dn)
            rdn = dn_parts[0]
            rdn_attr = rdn[0][0]
            raw_rdn = attrs.get(rdn_attr, '')
            if isinstance(raw_rdn, basestring):
                raw_rdn = [raw_rdn]
            # An RDN attribute given as an empty sequence supplies no new
            # RDN value (indexing it would raise IndexError).
            new_rdn = raw_rdn[0] if raw_rdn else ''

            if new_rdn:
                rdn_value = self._encode_incoming(new_rdn)
                if rdn_value != cur_rec.get(rdn_attr)[0]:
                    dn_parts[0] = [(rdn_attr, rdn_value, 1)]
                    raw_utf8_rdn = rdn_attr + '=' + rdn_value
                    new_rdn = escape_dn(raw_utf8_rdn)
                    connection.modrdn_s(dn, new_rdn)
                    dn = dn2str(dn_parts)

            if mod_list:
                connection.modify_s(dn, mod_list)
            else:
                debug_msg = 'Nothing to modify: %s' % dn
                self.logger().debug(debug_msg)

        except ldap.REFERRAL as e:
            # ``as`` form: valid on Python 2.6+ and required on Python 3.
            connection = self._handle_referral(e)
            connection.modify_s(dn, mod_list)