Exemple #1
0
    def run(self, groupname, credopts=None, sambaopts=None, versionopts=None, H=None):
        """Delete the group whose sAMAccountName equals ``groupname``.

        Opens a SamDB connection to ``H`` with the credentials derived from
        ``credopts`` and ``sambaopts``, searches the domain subtree for a
        matching group object and deletes the first result.

        Raises:
            CommandError: if no matching group is found, or if the delete
                fails (the underlying exception is passed along).
        """
        lp = sambaopts.get_loadparm()
        creds = credopts.get_credentials(lp, fallback_machine=True)
        samdb = SamDB(url=H, session_info=system_session(),
                      credentials=creds, lp=lp)

        # Escape the user-supplied name so that LDAP filter metacharacters
        # such as '(', ')' or '*' cannot change the search expression.
        filter = ("(&(sAMAccountName=%s)(objectClass=group))" %
                  ldb.binary_encode(groupname))

        try:
            res = samdb.search(base=samdb.domain_dn(),
                               scope=ldb.SCOPE_SUBTREE,
                               expression=filter,
                               attrs=["dn"])
            # An empty result raises IndexError here.
            group_dn = res[0].dn
        except IndexError:
            raise CommandError('Unable to find group "%s"' % (groupname))

        try:
            samdb.delete(group_dn)
        except Exception as e:
            # FIXME: catch more specific exception
            raise CommandError('Failed to remove group "%s"' % groupname, e)
        self.outf.write("Deleted group %s\n" % groupname)
Exemple #2
0
    def run(self, ou_dn, credopts=None, sambaopts=None, versionopts=None,
            H=None, force_subtree_delete=False):
        """Delete the organizational unit identified by ``ou_dn``.

        ``ou_dn`` is normalised with ``samdb.normalize_dn_in_domain`` before
        use. When ``force_subtree_delete`` is true the delete is issued with
        the ``tree_delete:1`` control.

        If the base search finds no organizationalUnit at that DN, a message
        is written to ``self.outf`` and the method returns without raising.

        Raises:
            CommandError: if ``ou_dn`` cannot be normalised, or if the search
                or delete fails.
        """
        lp = sambaopts.get_loadparm()
        creds = credopts.get_credentials(lp, fallback_machine=True)
        samdb = SamDB(url=H, session_info=system_session(),
                      credentials=creds, lp=lp)

        try:
            full_ou_dn = samdb.normalize_dn_in_domain(ou_dn)
        except Exception as e:
            raise CommandError('Invalid ou_dn "%s": %s' % (ou_dn, e))

        controls = []
        if force_subtree_delete:
            controls = ["tree_delete:1"]

        try:
            res = samdb.search(base=full_ou_dn,
                               expression="(objectclass=organizationalUnit)",
                               scope=ldb.SCOPE_BASE, attrs=[])
            if len(res) == 0:
                self.outf.write('Unable to find ou "%s"\n' % ou_dn)
                return
            samdb.delete(full_ou_dn, controls)
        except Exception as e:
            raise CommandError('Failed to delete ou "%s"' % full_ou_dn, e)

        self.outf.write('Deleted ou "%s"\n' % full_ou_dn)
Exemple #3
0
    def run(self, groupname, credopts=None, sambaopts=None, versionopts=None, H=None):
        """Remove the group object whose sAMAccountName is ``groupname``.

        A SamDB connection to ``H`` is opened, the domain subtree is searched
        for a group with that account name, and the first hit is deleted.

        Raises:
            CommandError: when the search returns nothing, or when the delete
                raises (the original exception is attached).
        """
        lp = sambaopts.get_loadparm()
        creds = credopts.get_credentials(lp, fallback_machine=True)
        samdb = SamDB(url=H, session_info=system_session(),
                      credentials=creds, lp=lp)

        # The name comes from the command line: encode it so characters that
        # are special in LDAP filters are matched literally.
        filter = ("(&(sAMAccountName=%s)(objectClass=group))" %
                  ldb.binary_encode(groupname))

        try:
            res = samdb.search(base=samdb.domain_dn(),
                               scope=ldb.SCOPE_SUBTREE,
                               expression=filter,
                               attrs=["dn"])
            # Indexing an empty result raises IndexError.
            group_dn = res[0].dn
        except IndexError:
            raise CommandError('Unable to find group "%s"' % (groupname))

        try:
            samdb.delete(group_dn)
        except Exception as e:
            # FIXME: catch more specific exception
            raise CommandError('Failed to remove group "%s"' % groupname, e)
        self.outf.write("Deleted group %s\n" % groupname)
Exemple #4
0
    def run(self,
            computername,
            credopts=None,
            sambaopts=None,
            versionopts=None,
            H=None):
        """Delete a workstation computer account and its DNS references.

        The account is looked up by sAMAccountName (a trailing '$' is
        appended when missing) restricted to ATYPE_WORKSTATION_TRUST.
        Accounts whose userAccountControl lacks UF_WORKSTATION_TRUST_ACCOUNT
        are refused.

        Raises:
            CommandError: if the computer is not found, is not a workstation,
                or the delete / DNS clean-up fails.
        """
        lp = sambaopts.get_loadparm()
        creds = credopts.get_credentials(lp, fallback_machine=True)
        samdb = SamDB(url=H, session_info=system_session(),
                      credentials=creds, lp=lp)

        if computername.endswith('$'):
            samaccountname = computername
        else:
            samaccountname = "%s$" % computername

        search_filter = "(&(sAMAccountName=%s)(sAMAccountType=%u))" % (
            ldb.binary_encode(samaccountname),
            dsdb.ATYPE_WORKSTATION_TRUST)
        try:
            res = samdb.search(base=samdb.domain_dn(),
                               scope=ldb.SCOPE_SUBTREE,
                               expression=search_filter,
                               attrs=["userAccountControl", "dNSHostName"])
            computer_dn = res[0].dn
            computer_ac = int(res[0]["userAccountControl"][0])
            # dNSHostName is optional on the record.
            computer_dns_host_name = None
            if "dNSHostName" in res[0]:
                computer_dns_host_name = res[0]["dNSHostName"][0]
        except IndexError:
            raise CommandError('Unable to find computer "%s"' % computername)

        if not (computer_ac & dsdb.UF_WORKSTATION_TRUST_ACCOUNT):
            raise CommandError(
                'Failed to remove computer "%s": '
                'Computer is not a workstation - removal denied' %
                computername)

        try:
            samdb.delete(computer_dn)
            if computer_dns_host_name:
                logger = self.get_logger()
                remove_dns_references(samdb, logger, computer_dns_host_name,
                                      ignore_no_name=True)
        except Exception as e:
            raise CommandError(
                'Failed to remove computer "%s"' % samaccountname, e)
        self.outf.write("Deleted computer %s\n" % computername)
Exemple #5
0
    def test_add_replicated_objects(self):
        """Check that adding these objects via self.samdb yields a referral.

        Each add must raise an LdbError with ERR_REFERRAL whose message holds
        an ldap:// URL. The same object is then added to and deleted from the
        referred server, and the test fails if the lower-cased address starts
        with the domain DNS name.
        """
        candidates = (
            {
                'dn': "ou=%s1,%s" % (self.tag, self.base_dn),
                "objectclass": "organizationalUnit"
            },
            {
                'dn': "cn=%s2,%s" % (self.tag, self.base_dn),
                "objectclass": "user"
            },
            {
                'dn': "cn=%s3,%s" % (self.tag, self.base_dn),
                "objectclass": "group"
            },
            {
                'dn': "cn=%s4,%s" % (self.tag, self.service),
                "objectclass": "NTDSConnection",
                'enabledConnection': 'TRUE',
                'fromServer': self.base_dn,
                'options': '0'
            },
        )
        for o in candidates:
            try:
                self.samdb.add(o)
                self.fail("Failed to fail to add %s" % o['dn'])
            except ldb.LdbError as e:
                (ecode, emsg) = e.args
                if ecode == ldb.ERR_REFERRAL:
                    # Pull the referred URL out of the error message.
                    referral = re.search(r'(ldap://[^>]+)>', emsg)
                    if referral is None:
                        self.fail("referral seems not to refer to anything")
                    address = referral.group(1)

                    try:
                        tmpdb = SamDB(address,
                                      credentials=CREDS,
                                      session_info=system_session(LP),
                                      lp=LP)
                        tmpdb.add(o)
                        tmpdb.delete(o['dn'])
                    except ldb.LdbError:
                        self.fail("couldn't modify referred location %s" %
                                  address)

                    domain_name = self.samdb.domain_dns_name()
                    if address.lower().startswith(domain_name):
                        self.fail(
                            "referral address did not give a specific DC")
                else:
                    print(emsg)
                    self.fail("Adding %s: ldb error: %s %s, wanted referral" %
                              (o['dn'], ecode, emsg))
Exemple #6
0
    def run(self, psoname, H=None, credopts=None, sambaopts=None,
            versionopts=None):
        """Delete the PSO named ``psoname`` from the PSO container.

        The DN is built as ``CN=<psoname>,<pso_container(samdb)>`` and is
        passed to ``check_pso_valid`` before the delete is issued.
        """
        lp = sambaopts.get_loadparm()
        creds = credopts.get_credentials(lp)
        samdb = SamDB(url=H, session_info=system_session(),
                      credentials=creds, lp=lp)

        container = pso_container(samdb)
        pso_dn = "CN=%s,%s" % (psoname, container)

        # Sanity check on the PSO before touching the database.
        check_pso_valid(samdb, pso_dn, psoname)

        samdb.delete(pso_dn)
        self.message("Deleted PSO %s" % psoname)
Exemple #7
0
    def run(self, psoname, H=None, credopts=None, sambaopts=None,
            versionopts=None):
        """Remove the Password Settings Object called ``psoname``.

        Connects to the SAM database at ``H``, derives the PSO's DN from
        ``pso_container(samdb)``, runs ``check_pso_valid`` on it and then
        deletes it, reporting the result through ``self.message``.
        """
        lp = sambaopts.get_loadparm()
        creds = credopts.get_credentials(lp)
        session = system_session()
        samdb = SamDB(url=H, session_info=session, credentials=creds, lp=lp)

        parent_dn = pso_container(samdb)
        pso_dn = "CN=%s,%s" % (psoname, parent_dn)
        check_pso_valid(samdb, pso_dn, psoname)

        samdb.delete(pso_dn)
        self.message("Deleted PSO %s" % psoname)
Exemple #8
0
    def run(self,
            contactname,
            sambaopts=None,
            credopts=None,
            versionopts=None,
            H=None):
        """Delete a contact given either its name or its DN.

        A ``contactname`` starting with "CN=" (case-insensitive) is treated
        as a DN and looked up with a base-scope search; otherwise the domain
        subtree is searched for a contact with that ``name``.

        Raises:
            CommandError: if the DN is invalid, nothing is found, more than
                one contact matches, or the delete fails.
        """
        lp = sambaopts.get_loadparm()
        creds = credopts.get_credentials(lp, fallback_machine=True)
        samdb = SamDB(url=H, session_info=system_session(),
                      credentials=creds, lp=lp)

        # Default: look the contact up by name anywhere in the domain.
        search_base = samdb.domain_dn()
        search_scope = ldb.SCOPE_SUBTREE
        search_filter = ("(&(objectClass=contact)(name=%s))" %
                         ldb.binary_encode(contactname))

        given_as_dn = contactname.upper().startswith("CN=")
        if given_as_dn:
            # Contact is specified by DN: search exactly that entry.
            search_filter = "(objectClass=contact)"
            search_scope = ldb.SCOPE_BASE
            try:
                search_base = samdb.normalize_dn_in_domain(contactname)
            except Exception as e:
                raise CommandError('Invalid dn "%s": %s' %
                                   (contactname, e))

        try:
            res = samdb.search(base=search_base,
                               scope=search_scope,
                               expression=search_filter,
                               attrs=["dn"])
            contact_dn = res[0].dn
        except IndexError:
            raise CommandError('Unable to find contact "%s"' % (contactname))

        if len(res) > 1:
            # Ambiguous name: list every candidate, then refuse.
            for match in sorted(res, key=attrgetter('dn')):
                self.outf.write("found: %s\n" % match.dn)
            raise CommandError("Multiple results for contact '%s'\n"
                               "Please specify the contact's full DN" %
                               contactname)

        try:
            samdb.delete(contact_dn)
        except Exception as e:
            raise CommandError('Failed to remove contact "%s"' % contactname, e)
        self.outf.write("Deleted contact %s\n" % contactname)
Exemple #9
0
    def test_add_replicated_objects(self):
        """Adding each object through self.samdb must produce a referral.

        The add is expected to raise LdbError with ERR_REFERRAL and a message
        containing an ldap:// URL. The object is then added to and removed
        from the server at that URL. The test fails if the lower-cased URL
        starts with the domain DNS name.
        """
        base = self.base_dn
        objects = (
            {
                'dn': "ou=%s1,%s" % (self.tag, base),
                "objectclass": "organizationalUnit"
            },
            {
                'dn': "cn=%s2,%s" % (self.tag, base),
                "objectclass": "user"
            },
            {
                'dn': "cn=%s3,%s" % (self.tag, base),
                "objectclass": "group"
            },
            {
                'dn': "cn=%s4,%s" % (self.tag, self.service),
                "objectclass": "NTDSConnection",
                'enabledConnection': 'TRUE',
                'fromServer': base,
                'options': '0'
            },
        )
        for o in objects:
            try:
                self.samdb.add(o)
                self.fail("Failed to fail to add %s" % o['dn'])
            except ldb.LdbError as e:
                (ecode, emsg) = e.args
                if ecode == ldb.ERR_REFERRAL:
                    found = re.search(r'(ldap://[^>]+)>', emsg)
                    if found is None:
                        self.fail("referral seems not to refer to anything")
                    address = found.group(1)

                    try:
                        tmpdb = SamDB(address, credentials=CREDS,
                                      session_info=system_session(LP), lp=LP)
                        tmpdb.add(o)
                        tmpdb.delete(o['dn'])
                    except ldb.LdbError:
                        self.fail("couldn't modify referred location %s" %
                                  address)

                    lowered = address.lower()
                    if lowered.startswith(self.samdb.domain_dns_name()):
                        self.fail("referral address did not give a specific DC")
                else:
                    print(emsg)
                    self.fail("Adding %s: ldb error: %s %s, wanted referral" %
                              (o['dn'], ecode, emsg))
Exemple #10
0
class LATests(samba.tests.TestCase):
    """Test case that works inside a dedicated "OU=la" organizational unit."""

    def setUp(self):
        """Connect to the SAM database and create the test OU.

        When ``opts.delete_in_setup`` is set, a tree delete of the OU is
        attempted first; an LdbError from that delete is only reported.
        """
        super(LATests, self).setUp()
        self.samdb = SamDB(host, credentials=creds,
                           session_info=system_session(lp), lp=lp)

        self.base_dn = self.samdb.domain_dn()
        self.ou = "OU=la,%s" % self.base_dn
        if opts.delete_in_setup:
            try:
                self.samdb.delete(self.ou, ['tree_delete:1'])
            except ldb.LdbError as e:
                # "as" form and parenthesised print are valid on both
                # Python 2.6+ and Python 3.
                print("tried deleting %s, got error %s" % (self.ou, e))
        self.samdb.add({'objectclass': 'organizationalUnit',
                        'dn': self.ou})
class LATests(samba.tests.TestCase):
    """Tests run against a scratch organizational unit named "OU=la"."""

    def setUp(self):
        """Open the SamDB connection and (re)create the scratch OU.

        If ``opts.delete_in_setup`` is true the OU is first removed with the
        ``tree_delete:1`` control; a failing delete is printed, not raised.
        """
        super(LATests, self).setUp()
        self.samdb = SamDB(host,
                           credentials=creds,
                           session_info=system_session(lp),
                           lp=lp)

        self.base_dn = self.samdb.domain_dn()
        self.ou = "OU=la,%s" % self.base_dn
        if opts.delete_in_setup:
            try:
                self.samdb.delete(self.ou, ['tree_delete:1'])
            except ldb.LdbError as e:
                # Syntax chosen to parse on Python 2.6+ as well as Python 3.
                print("tried deleting %s, got error %s" % (self.ou, e))
        self.samdb.add({'objectclass': 'organizationalUnit', 'dn': self.ou})
Exemple #12
0
    def run(self, computername, credopts=None, sambaopts=None,
            versionopts=None, H=None):
        """Delete the workstation account ``computername``.

        The sAMAccountName searched for is ``computername`` with a trailing
        '$' appended if absent, restricted to ATYPE_WORKSTATION_TRUST. After
        the delete, DNS references for the account's dNSHostName (when it has
        one) are removed via ``remove_dns_references``.

        Raises:
            CommandError: if the account is not found, its
                userAccountControl lacks UF_WORKSTATION_TRUST_ACCOUNT, or the
                delete / DNS clean-up fails.
        """
        lp = sambaopts.get_loadparm()
        creds = credopts.get_credentials(lp, fallback_machine=True)

        samdb = SamDB(url=H, session_info=system_session(),
                      credentials=creds, lp=lp)

        samaccountname = computername
        if not computername.endswith('$'):
            samaccountname = "%s$" % computername

        filter = ("(&(sAMAccountName=%s)(sAMAccountType=%u))" %
                  (ldb.binary_encode(samaccountname),
                   dsdb.ATYPE_WORKSTATION_TRUST))
        try:
            res = samdb.search(base=samdb.domain_dn(),
                               scope=ldb.SCOPE_SUBTREE,
                               expression=filter,
                               attrs=["userAccountControl", "dNSHostName"])
            computer_dn = res[0].dn
            computer_ac = int(res[0]["userAccountControl"][0])
            if "dNSHostName" in res[0]:
                computer_dns_host_name = res[0]["dNSHostName"][0]
            else:
                computer_dns_host_name = None
        except IndexError:
            raise CommandError('Unable to find computer "%s"' % computername)

        computer_is_workstation = (
            computer_ac & dsdb.UF_WORKSTATION_TRUST_ACCOUNT)
        # The mask yields an int, so test truthiness rather than comparing
        # against the False singleton.
        if not computer_is_workstation:
            raise CommandError('Failed to remove computer "%s": '
                               'Computer is not a workstation - removal denied'
                               % computername)
        try:
            samdb.delete(computer_dn)
            if computer_dns_host_name:
                remove_dns_references(
                    samdb, self.get_logger(), computer_dns_host_name,
                    ignore_no_name=True)
        except Exception as e:
            raise CommandError('Failed to remove computer "%s"' %
                               samaccountname, e)
        self.outf.write("Deleted computer %s\n" % computername)
Exemple #13
0
class DsdbTests(TestCase):
    """Tests for SamDB helpers, run against a freshly created test user."""

    def setUp(self):
        """Open a SamDB connection and create a uniquely named test user.

        The user is removed again through ``addCleanup`` / ``delete_force``.
        """
        super(DsdbTests, self).setUp()
        self.lp = samba.tests.env_loadparm()
        self.creds = Credentials()
        self.creds.guess(self.lp)
        self.session = system_session()
        self.samdb = SamDB(session_info=self.session,
                           credentials=self.creds,
                           lp=self.lp)

        # Create a test user
        user_name = "dsdb-user-" + str(uuid.uuid4().hex[0:6])
        user_pass = samba.generate_random_password(32, 32)
        user_description = "Test user for dsdb test"

        base_dn = self.samdb.domain_dn()

        self.account_dn = "cn=" + user_name + ",cn=Users," + base_dn
        self.samdb.newuser(username=user_name,
                           password=user_pass,
                           description=user_description)
        # Cleanup (teardown)
        self.addCleanup(delete_force, self.samdb, self.account_dn)

    def test_get_oid_from_attrid(self):
        """get_oid_from_attid(591614) must return 1.2.840.113556.1.4.1790."""
        oid = self.samdb.get_oid_from_attid(591614)
        self.assertEqual(oid, "1.2.840.113556.1.4.1790")

    def test_error_replpropertymetadata(self):
        """Replacing replPropertyMetaData with only a bumped version fails.

        The version of the attid 13 entry is incremented and the blob is
        written back with the 1.3.6.1.4.1.7165.4.3.14 control; the modify is
        expected to raise LdbError.
        """
        res = self.samdb.search(scope=ldb.SCOPE_SUBTREE,
                                base=self.account_dn,
                                attrs=["replPropertyMetaData"])
        repl = ndr_unpack(drsblobs.replPropertyMetaDataBlob,
                          res[0]["replPropertyMetaData"][0])
        ctr = repl.ctr
        for o in ctr.array:
            # Search for Description
            if o.attid == 13:
                o.version = o.version + 1
        replBlob = ndr_pack(repl)
        msg = ldb.Message()
        msg.dn = res[0].dn
        msg["replPropertyMetaData"] = ldb.MessageElement(
            replBlob, ldb.FLAG_MOD_REPLACE, "replPropertyMetaData")
        self.assertRaises(ldb.LdbError, self.samdb.modify, msg,
                          ["local_oid:1.3.6.1.4.1.7165.4.3.14:0"])

    def test_error_replpropertymetadata_nochange(self):
        """Writing back an unchanged replPropertyMetaData blob must fail."""
        res = self.samdb.search(scope=ldb.SCOPE_SUBTREE,
                                base=self.account_dn,
                                attrs=["replPropertyMetaData"])
        repl = ndr_unpack(drsblobs.replPropertyMetaDataBlob,
                          res[0]["replPropertyMetaData"][0])
        replBlob = ndr_pack(repl)
        msg = ldb.Message()
        msg.dn = res[0].dn
        msg["replPropertyMetaData"] = ldb.MessageElement(
            replBlob, ldb.FLAG_MOD_REPLACE, "replPropertyMetaData")
        self.assertRaises(ldb.LdbError, self.samdb.modify, msg,
                          ["local_oid:1.3.6.1.4.1.7165.4.3.14:0"])

    def test_error_replpropertymetadata_allow_sort(self):
        """The unchanged blob is accepted when both controls are supplied.

        Same write as the "nochange" test, but with the additional
        1.3.6.1.4.1.7165.4.3.25 control; the modify must not raise.
        """
        res = self.samdb.search(scope=ldb.SCOPE_SUBTREE,
                                base=self.account_dn,
                                attrs=["replPropertyMetaData"])
        repl = ndr_unpack(drsblobs.replPropertyMetaDataBlob,
                          res[0]["replPropertyMetaData"][0])
        replBlob = ndr_pack(repl)
        msg = ldb.Message()
        msg.dn = res[0].dn
        msg["replPropertyMetaData"] = ldb.MessageElement(
            replBlob, ldb.FLAG_MOD_REPLACE, "replPropertyMetaData")
        self.samdb.modify(msg, [
            "local_oid:1.3.6.1.4.1.7165.4.3.14:0",
            "local_oid:1.3.6.1.4.1.7165.4.3.25:0"
        ])

    def test_twoatt_replpropertymetadata(self):
        """Replacing replPropertyMetaData together with description fails."""
        res = self.samdb.search(scope=ldb.SCOPE_SUBTREE,
                                base=self.account_dn,
                                attrs=["replPropertyMetaData", "uSNChanged"])
        repl = ndr_unpack(drsblobs.replPropertyMetaDataBlob,
                          res[0]["replPropertyMetaData"][0])
        ctr = repl.ctr
        for o in ctr.array:
            # Search for Description
            if o.attid == 13:
                o.version = o.version + 1
                o.local_usn = int(str(res[0]["uSNChanged"])) + 1
        replBlob = ndr_pack(repl)
        msg = ldb.Message()
        msg.dn = res[0].dn
        msg["replPropertyMetaData"] = ldb.MessageElement(
            replBlob, ldb.FLAG_MOD_REPLACE, "replPropertyMetaData")
        msg["description"] = ldb.MessageElement("new val",
                                                ldb.FLAG_MOD_REPLACE,
                                                "description")
        self.assertRaises(ldb.LdbError, self.samdb.modify, msg,
                          ["local_oid:1.3.6.1.4.1.7165.4.3.14:0"])

    def test_set_replpropertymetadata(self):
        """A blob with bumped version and both USNs is accepted.

        For the attid 13 entry the version, local_usn and originating_usn
        are all advanced before the blob is written back; the modify must
        succeed.
        """
        res = self.samdb.search(scope=ldb.SCOPE_SUBTREE,
                                base=self.account_dn,
                                attrs=["replPropertyMetaData", "uSNChanged"])
        repl = ndr_unpack(drsblobs.replPropertyMetaDataBlob,
                          res[0]["replPropertyMetaData"][0])
        ctr = repl.ctr
        for o in ctr.array:
            # Search for Description
            if o.attid == 13:
                o.version = o.version + 1
                o.local_usn = int(str(res[0]["uSNChanged"])) + 1
                o.originating_usn = int(str(res[0]["uSNChanged"])) + 1
        replBlob = ndr_pack(repl)
        msg = ldb.Message()
        msg.dn = res[0].dn
        msg["replPropertyMetaData"] = ldb.MessageElement(
            replBlob, ldb.FLAG_MOD_REPLACE, "replPropertyMetaData")
        self.samdb.modify(msg, ["local_oid:1.3.6.1.4.1.7165.4.3.14:0"])

    def test_ok_get_attribute_from_attid(self):
        """attid 13 must map to the attribute name "description"."""
        self.assertEqual(self.samdb.get_attribute_from_attid(13),
                         "description")

    def test_ko_get_attribute_from_attid(self):
        """attid 11979 is expected to map to no attribute (None)."""
        self.assertEqual(self.samdb.get_attribute_from_attid(11979), None)

    def test_get_attribute_replmetadata_version(self):
        """The test user's unicodePwd metadata version is expected to be 2."""
        res = self.samdb.search(scope=ldb.SCOPE_SUBTREE,
                                base=self.account_dn,
                                attrs=["dn"])
        self.assertEqual(len(res), 1)
        dn = str(res[0].dn)
        self.assertEqual(
            self.samdb.get_attribute_replmetadata_version(dn, "unicodePwd"), 2)

    def test_set_attribute_replmetadata_version(self):
        """A metadata version set on "description" must be read back."""
        res = self.samdb.search(scope=ldb.SCOPE_SUBTREE,
                                base=self.account_dn,
                                attrs=["dn"])
        self.assertEqual(len(res), 1)
        dn = str(res[0].dn)
        version = self.samdb.get_attribute_replmetadata_version(
            dn, "description")
        self.samdb.set_attribute_replmetadata_version(dn, "description",
                                                      version + 2)
        self.assertEqual(
            self.samdb.get_attribute_replmetadata_version(dn, "description"),
            version + 2)

    def test_no_error_on_invalid_control(self):
        """A search with the invalid control flagged ":0" must not raise."""
        try:
            self.samdb.search(
                scope=ldb.SCOPE_SUBTREE,
                base=self.account_dn,
                attrs=["replPropertyMetaData"],
                controls=[
                    "local_oid:%s:0" %
                    dsdb.DSDB_CONTROL_INVALID_NOT_IMPLEMENTED
                ])
        except ldb.LdbError:
            self.fail("Should have not raised an exception")

    def test_error_on_invalid_critical_control(self):
        """With the control flagged ":1", only one error code is tolerated.

        If the search raises LdbError, its code must be
        ERR_UNSUPPORTED_CRITICAL_EXTENSION. A search that does not raise is
        not treated as a failure here.
        """
        try:
            self.samdb.search(
                scope=ldb.SCOPE_SUBTREE,
                base=self.account_dn,
                attrs=["replPropertyMetaData"],
                controls=[
                    "local_oid:%s:1" %
                    dsdb.DSDB_CONTROL_INVALID_NOT_IMPLEMENTED
                ])
        except ldb.LdbError as e:
            (errno, estr) = e.args
            if errno != ldb.ERR_UNSUPPORTED_CRITICAL_EXTENSION:
                # Use the unpacked message: exceptions cannot be indexed
                # on Python 3.
                self.fail(
                    "Got %s should have got ERR_UNSUPPORTED_CRITICAL_EXTENSION"
                    % estr)

    # Allocate a unique RID for use in the objectSID tests.
    #
    def allocate_rid(self):
        """Allocate a RID inside a transaction and return it as a string.

        The transaction is cancelled and the exception re-raised if the
        allocation fails; otherwise it is committed.
        """
        self.samdb.transaction_start()
        try:
            rid = self.samdb.allocate_rid()
        except:
            self.samdb.transaction_cancel()
            raise
        self.samdb.transaction_commit()
        return str(rid)

    # Ensure that duplicate objectSID's are permitted for foreign security
    # principals.
    #
    def test_duplicate_objectSIDs_allowed_on_foreign_security_principals(self):
        """Add, delete and re-add a foreign security principal."""

        #
        # We need to build a foreign security principal SID
        # i.e a  SID not in the current domain.
        #
        dom_sid = self.samdb.get_domain_sid()
        # Change the last character of the domain SID so it differs.
        if str(dom_sid).endswith("0"):
            c = "9"
        else:
            c = "0"
        sid = str(dom_sid)[:-1] + c + "-1000"
        basedn = self.samdb.get_default_basedn()
        dn = "CN=%s,CN=ForeignSecurityPrincipals,%s" % (sid, basedn)
        self.samdb.add({"dn": dn, "objectClass": "foreignSecurityPrincipal"})

        self.samdb.delete(dn)

        try:
            self.samdb.add({
                "dn": dn,
                "objectClass": "foreignSecurityPrincipal"
            })
        except ldb.LdbError as e:
            (code, msg) = e.args
            self.fail("Got unexpected exception %d - %s " % (code, msg))

        # cleanup
        self.samdb.delete(dn)

    #
    # Duplicate objectSID's should not be permitted for sids in the local
    # domain. The test sequence is add an object, delete it, then attempt to
    # re-add it, this should fail with a constraint violation
    #
    def test_duplicate_objectSIDs_not_allowed_on_local_objects(self):
        """Re-adding a deleted local user with the same SID must fail."""

        dom_sid = self.samdb.get_domain_sid()
        rid = self.allocate_rid()
        sid_str = str(dom_sid) + "-" + rid
        sid = ndr_pack(security.dom_sid(sid_str))
        basedn = self.samdb.get_default_basedn()
        cn = "dsdb_test_01"
        dn = "cn=%s,cn=Users,%s" % (cn, basedn)

        self.samdb.add({"dn": dn, "objectClass": "user", "objectSID": sid})
        self.samdb.delete(dn)

        try:
            self.samdb.add({"dn": dn, "objectClass": "user", "objectSID": sid})
            self.fail("No exception should get LDB_ERR_CONSTRAINT_VIOLATION")
        except ldb.LdbError as e:
            (code, msg) = e.args
            if code != ldb.ERR_CONSTRAINT_VIOLATION:
                self.fail("Got %d - %s should have got "
                          "LDB_ERR_CONSTRAINT_VIOLATION" % (code, msg))

    def test_normalize_dn_in_domain_full(self):
        """A full DN given as a string is returned unchanged."""
        domain_dn = self.samdb.domain_dn()

        part_dn = ldb.Dn(self.samdb, "CN=Users")

        full_dn = part_dn
        full_dn.add_base(domain_dn)

        full_str = str(full_dn)

        # That is, no change
        self.assertEqual(full_dn, self.samdb.normalize_dn_in_domain(full_str))

    def test_normalize_dn_in_domain_part(self):
        """A partial DN string gets the domain DN appended."""
        domain_dn = self.samdb.domain_dn()

        part_str = "CN=Users"

        full_dn = ldb.Dn(self.samdb, part_str)
        full_dn.add_base(domain_dn)

        # That is, the domain DN appended
        self.assertEqual(full_dn, self.samdb.normalize_dn_in_domain(part_str))

    def test_normalize_dn_in_domain_full_dn(self):
        """A full DN given as an ldb.Dn is returned unchanged."""
        domain_dn = self.samdb.domain_dn()

        part_dn = ldb.Dn(self.samdb, "CN=Users")

        full_dn = part_dn
        full_dn.add_base(domain_dn)

        # That is, no change
        self.assertEqual(full_dn, self.samdb.normalize_dn_in_domain(full_dn))

    def test_normalize_dn_in_domain_part_dn(self):
        """A partial ldb.Dn gets the domain DN appended."""
        domain_dn = self.samdb.domain_dn()

        part_dn = ldb.Dn(self.samdb, "CN=Users")

        # That is, the domain DN appended
        self.assertEqual(
            ldb.Dn(self.samdb,
                   str(part_dn) + "," + str(domain_dn)),
            self.samdb.normalize_dn_in_domain(part_dn))
Exemple #14
0
	# Check if dns-HOSTNAME account exists and create it if required
	try:
		# Look the dns-<hostname> principal up in the secrets database and
		# read its stored secret.
		dn = 'samAccountName=dns-%s,CN=Principals' % names.hostname
		msg = secretsdb.search(expression='(dn=%s)' % dn, attrs=['secret'])
		dnssecret = msg[0]['secret'][0]
	except Exception:
		# Any failure above (presumably including an empty search result)
		# is treated as "account missing" and triggers (re)creation.
		print "Adding dns-%s account" % names.hostname
	
		try:
			# Remove a matching sAMAccountName entry from the SAM database,
			# if one is found, before the secrets entry is added.
			msg = samdb.search(base=names.domaindn, scope=samba.ldb.SCOPE_DEFAULT,
				expression='(sAMAccountName=dns-%s)' % (names.hostname),
				attrs=['clearTextPassword'])
			if msg:
				print "removing sAMAccountName=dns-%s" % (names.hostname)
				dn = msg[0].dn
				samdb.delete(dn)
		except Exception:
			# Best effort: a failed removal is reported and then ignored.
			print "exception while removing sAMAccountName=dns-%s" % (names.hostname)
			pass

		# Add the entry to the secrets database from the secrets_dns.ldif
		# template, filling in the substitutions below.
		setup_add_ldif(secretsdb, setup_path("secrets_dns.ldif"), {
			"REALM": names.realm,
			"DNSDOMAIN": names.dnsdomain,
			"DNS_KEYTAB": dns_keytab_path,
			"DNSPASS_B64": b64encode(dnspass),
			"HOSTNAME": names.hostname,
			"DNSNAME" : '%s.%s' % (
				names.netbiosname.lower(), names.dnsdomain.lower())
			})

		account_created = False
Exemple #15
0
class LATests(samba.tests.TestCase):
    def setUp(self):
        """Connect to the SAM database and create the "OU=la" test OU.

        When ``opts.delete_in_setup`` is set, the OU is first removed with
        the ``tree_delete:1`` control; an LdbError from that delete is
        printed and otherwise ignored.
        """
        super(LATests, self).setUp()
        self.samdb = SamDB(host,
                           credentials=creds,
                           session_info=system_session(lp),
                           lp=lp)

        self.base_dn = self.samdb.domain_dn()
        self.ou = "OU=la,%s" % self.base_dn
        if opts.delete_in_setup:
            try:
                self.samdb.delete(self.ou, ['tree_delete:1'])
            except ldb.LdbError as e:
                # Parenthesised single-argument print parses on Python 2
                # and Python 3 alike.
                print("tried deleting %s, got error %s" % (self.ou, e))
        self.samdb.add({'objectclass': 'organizationalUnit', 'dn': self.ou})

    def tearDown(self):
        """Tree-delete the test OU unless ``opts.no_cleanup`` is set."""
        super(LATests, self).tearDown()
        if opts.no_cleanup:
            return
        self.samdb.delete(self.ou, ['tree_delete:1'])

    def add_object(self, cn, objectclass, more_attrs={}):
        """Add an object named ``cn`` under the test OU and return its DN.

        ``more_attrs`` is merged over the basic cn/objectclass/dn record
        before the add; it is only read, never modified.
        """
        dn = "CN=%s,%s" % (cn, self.ou)
        record = dict(cn=cn, objectclass=objectclass, dn=dn)
        record.update(more_attrs)
        self.samdb.add(record)

        return dn

    def add_objects(self, n, objectclass, prefix=None, more_attrs={}):
        """Add ``n`` objects named ``<prefix>1`` .. ``<prefix>n``.

        ``prefix`` defaults to ``objectclass``. Returns the list of DNs in
        creation order.
        """
        name_prefix = objectclass if prefix is None else prefix
        return [
            self.add_object("%s%d" % (name_prefix, index + 1),
                            objectclass,
                            more_attrs=more_attrs)
            for index in range(n)
        ]

    def add_linked_attribute(self, src, dest, attr='member', controls=None):
        """Add ``dest`` as a value of ``attr`` on the object ``src``."""
        change = ldb.Message()
        change.dn = ldb.Dn(self.samdb, src)
        element = ldb.MessageElement(dest, ldb.FLAG_MOD_ADD, attr)
        change[attr] = element
        self.samdb.modify(change, controls=controls)

    def remove_linked_attribute(self, src, dest, attr='member', controls=None):
        """Delete the value ``dest`` from ``attr`` on the object ``src``."""
        change = ldb.Message()
        change.dn = ldb.Dn(self.samdb, src)
        element = ldb.MessageElement(dest, ldb.FLAG_MOD_DELETE, attr)
        change[attr] = element
        self.samdb.modify(change, controls=controls)

    def replace_linked_attribute(self,
                                 src,
                                 dest,
                                 attr='member',
                                 controls=None):
        """Replace the values of ``attr`` on ``src`` with ``dest``."""
        change = ldb.Message()
        change.dn = ldb.Dn(self.samdb, src)
        element = ldb.MessageElement(dest, ldb.FLAG_MOD_REPLACE, attr)
        change[attr] = element
        self.samdb.modify(change, controls=controls)

    def attr_search(self, obj, attr, scope=ldb.SCOPE_BASE, **controls):
        """Search ``obj`` for ``attr`` and return the search result.

        Extra keyword arguments are turned into ``name:int(value)`` control
        strings. ``reveal_internals`` is dropped from them when
        ``opts.no_reveal_internals`` is set.
        """
        if opts.no_reveal_internals:
            controls.pop('reveal_internals', None)

        control_strings = []
        for name, value in controls.items():
            control_strings.append('%s:%d' % (name, int(value)))

        return self.samdb.search(obj,
                                 scope=scope,
                                 attrs=[attr],
                                 controls=control_strings)

    def assert_links(self, obj, expected, attr, msg='', **kwargs):
        """Assert that ``attr`` on ``obj`` holds exactly ``expected``.

        Values are compared after sorting both sides. An empty ``expected``
        means the attribute must be absent from the first search result.
        ``msg`` is printed, with both lists, when they differ. ``kwargs`` are
        forwarded to ``attr_search``.
        """
        res = self.attr_search(obj, attr, **kwargs)

        if len(expected) == 0:
            if attr in res[0]:
                self.fail("found attr '%s' in %s" % (attr, res[0]))
            return

        try:
            results = list([x[attr] for x in res][0])
        except KeyError:
            self.fail("missing attr '%s' on %s" % (attr, obj))

        expected = sorted(expected)
        results = sorted(results)

        if expected != results:
            # Parenthesised single-argument print works on Python 2 and 3.
            print(msg)
            print("expected %s" % expected)
            print("received %s" % results)

        self.assertEqual(results, expected)

    def assert_back_links(self, obj, expected, attr='memberOf', **kwargs):
        """Run ``assert_links`` on ``attr`` with the back-link message."""
        failure_note = 'back links do not match'
        self.assert_links(obj, expected, attr=attr, msg=failure_note,
                          **kwargs)

    def assert_forward_links(self, obj, expected, attr='member', **kwargs):
        """Run ``assert_links`` on ``attr`` with the forward-link message."""
        failure_note = 'forward links do not match'
        self.assert_links(obj, expected, attr=attr, msg=failure_note,
                          **kwargs)

    def get_object_guid(self, dn):
        """Return the objectGUID of ``dn`` as a string."""
        res = self.samdb.search(dn, scope=ldb.SCOPE_BASE, attrs=['objectGUID'])
        raw_guid = res[0]['objectGUID'][0]
        guid = misc.GUID(raw_guid)
        return str(guid)

    def assertRaisesLdbError(self, errcode, msg, f, *args, **kwargs):
        """Assert a function raises a particular LdbError."""
        try:
            f(*args, **kwargs)
        except ldb.LdbError as (num, msg):
            if num != errcode:
                lut = {
                    v: k
                    for k, v in vars(ldb).iteritems()
                    if k.startswith('ERR_') and isinstance(v, int)
                }
                self.fail("%s, expected "
                          "LdbError %s, (%d) "
                          "got %s (%d)" %
                          (msg, lut.get(errcode), errcode, lut.get(num), num))
        else:
Exemple #16
0
class VLVTests(samba.tests.TestCase):

    def create_user(self, i, n, prefix='vlvtest', suffix='', attrs=None):
        """Build a test user, add it to the directory and remember it.

        The attribute values are derived from ``i`` and ``n`` so that the
        various attributes order the users differently.

        :param i: index of this user; used in the cn and most values
        :param n: number the index is measured against (values such as
            ``n - i``); callers in this class pass ``N_ELEMENTS`` or the
            current number of users
        :param prefix: leading part of the cn
        :param suffix: trailing part of the cn
        :param attrs: optional dict whose entries override or extend the
            generated attributes
        :return: the dict that was appended to ``self.users`` and passed
            to ``self.ldb.add``
        """
        name = "%s%d%s" % (prefix, i, suffix)
        user = {
            'cn': name,
            "objectclass": "user",
            'givenName': "abcdefghijklmnopqrstuvwxyz"[i % 26],
            "roomNumber": "%sbc" % (n - i),
            "carLicense": "后来经",
            "employeeNumber": "%s%sx" % (abs(i * (99 - i)), '\n' * (i & 255)),
            "accountExpires": "%s" % (10 ** 9 + 1000000 * i),
            "msTSExpireDate4": "19%02d0101010000.0Z" % (i % 100),
            "flags": str(i * (n - i)),
            "serialNumber": "abc %s%s%s" % ('AaBb |-/'[i & 7],
                                            ' 3z}'[i & 3],
                                            '"@'[i & 1],),
        }

        # _user_broken_attrs tests are broken due to problems outside
        # of VLV. The dict is intentionally never merged into `user`; it
        # is kept as a record of attributes that could be tested once
        # those problems are fixed.
        _user_broken_attrs = {
            # Sort doesn't look past a NUL byte.
            "photo": "\x00%d" % (n - i),
            "audio": "%sn octet string %s%s ♫♬\x00lalala" % ('Aa'[i & 1],
                                                             chr(i & 255), i),
            "displayNamePrintable": "%d\x00%c" % (i, i & 255),
            "adminDisplayName": "%d\x00b" % (n-i),
            "title": "%d%sb" % (n - i, '\x00' * i),
            "comment": "Favourite colour is %d" % (n % (i + 1)),

            # Names that vary only in case. Windows returns
            # equivalent addresses in the order they were put
            # in ('a st', 'A st',...).
            "street": "%s st" % (chr(65 | (i & 14) | ((i & 1) * 32))),
        }

        if attrs is not None:
            user.update(attrs)

        # The dn is built from the (possibly overridden) cn.
        user['dn'] = "cn=%s,%s" % (user['cn'], self.ou)

        if opts.skip_attr_regex:
            match = re.compile(opts.skip_attr_regex).search
            # Walk a snapshot of the keys: deleting from a dict while
            # iterating over its live key view raises RuntimeError on
            # Python 3.
            for k in list(user):
                if match(k):
                    del user[k]

        self.users.append(user)
        self.ldb.add(user)
        return user

    def setUp(self):
        """Connect to the server and populate a fresh ``ou=vlv`` container.

        Adds ``N_ELEMENTS`` users through ``create_user`` and fills in the
        lists of attribute names (``binary_sorted_keys``,
        ``numeric_sorted_keys``, ``timestamp_keys``, ``int64_keys``,
        ``locale_sorted_keys``, ``delicate_keys``) that the tests consult.
        """
        super(VLVTests, self).setUp()
        self.ldb = SamDB(host, credentials=creds,
                         session_info=system_session(lp), lp=lp)

        self.base_dn = self.ldb.domain_dn()
        self.ou = "ou=vlv,%s" % self.base_dn
        if opts.delete_in_setup:
            # Best effort: a failure to delete is reported, not fatal.
            try:
                self.ldb.delete(self.ou, ['tree_delete:1'])
            except ldb.LdbError as e:
                print("tried deleting %s, got error %s" % (self.ou, e))

        container = {"dn": self.ou, "objectclass": "organizationalUnit"}
        self.ldb.add(container)

        self.users = []
        for index in range(N_ELEMENTS):
            self.create_user(index, N_ELEMENTS)

        self.binary_sorted_keys = [
            'audio',
            'photo',
            "msTSExpireDate4",
            'serialNumber',
            "displayNamePrintable",
        ]
        self.numeric_sorted_keys = ['flags', 'accountExpires']
        self.timestamp_keys = ['msTSExpireDate4']
        self.int64_keys = {'accountExpires'}

        # Every attribute of the first user that is in neither the binary
        # nor the numeric list goes into the locale-sorted list.
        other_orderings = self.binary_sorted_keys + self.numeric_sorted_keys
        self.locale_sorted_keys = [key for key in self.users[0].keys()
                                   if key not in other_orderings]

        # don't try spaces, etc in cn
        self.delicate_keys = ['cn']

    def tearDown(self):
        """Run the base-class teardown, then remove the test OU.

        The OU is left in place when ``opts.delete_in_setup`` is set; in
        that mode ``setUp`` deletes it before recreating it.
        """
        super(VLVTests, self).tearDown()
        if opts.delete_in_setup:
            return
        self.ldb.delete(self.ou, ['tree_delete:1'])

    def get_full_list(self, attr, include_cn=False):
        """Fetch the whole list sorted on the attribute, using the VLV.
        This way you get a VLV cookie.

        :param attr: attribute to sort on and fetch
        :param include_cn: when true, also fetch ``cn`` and return
            ``(value, cn)`` tuples; otherwise return lower-cased values
        :return: ``(full_results, controls, sort_control)`` where
            ``controls`` are the controls on the search result
        """
        sort_control = "server_sort:1:0:%s" % attr
        # Target the entry just past the middle and ask for half the
        # list on either side of it.
        middle = len(self.users) // 2
        vlv_search = "vlv:1:%d:%d:%d:0" % (middle, middle, middle + 1)
        wanted = [attr, 'cn'] if include_cn else [attr]

        res = self.ldb.search(self.ou,
                              scope=ldb.SCOPE_ONELEVEL,
                              attrs=wanted,
                              controls=[sort_control, vlv_search])

        if include_cn:
            full_results = [(msg[attr][0], msg['cn'][0]) for msg in res]
        else:
            full_results = [msg[attr][0].lower() for msg in res]
        return full_results, res.controls, sort_control

    def get_expected_order(self, attr, expression=None):
        """Fetch the whole list sorted on the attribute, using sort only.

        :param attr: attribute to sort on and fetch
        :param expression: optional search filter
        :return: list with the first ``attr`` value of each result, in
            the order the server returned them
        """
        res = self.ldb.search(self.ou,
                              scope=ldb.SCOPE_ONELEVEL,
                              expression=expression,
                              attrs=[attr],
                              controls=["server_sort:1:0:%s" % attr])
        return [msg[attr][0] for msg in res]

    def delete_user(self, user):
        """Delete ``user`` from the directory, then forget it locally.

        The directory delete happens first; the first equal entry is
        then removed from ``self.users``.
        """
        self.ldb.delete(user['dn'])
        self.users.remove(user)

    def get_gte_tests_and_order(self, attr, expression=None):
        """Work out greater-than-or-equal probes for ``attr`` and where
        each one should land.

        Temporarily adds one user per probe value (with ``attr`` set to
        that value), reads back the server's sort order with those users
        present, then deletes them again.

        :param attr: the attribute being sorted on
        :param expression: optional search filter used when fetching the
            expected order (the order including the probes is fetched
            without a filter)
        :return: ``(gte_order, expected_order, gte_map)`` where
            ``gte_order`` is the sorted value list including the probes,
            ``expected_order`` is the sorted value list without them, and
            ``gte_map`` maps every value in ``gte_order`` to the index in
            ``expected_order`` of the first entry at or after it
            (``len(expected_order)`` when there is none)
        """
        expected_order = self.get_expected_order(attr, expression=expression)
        if attr in self.delicate_keys:
            gte_keys = [
                '3',
                'abc',
                '¹',
                'ŋđ¼³ŧ“«đð',
                '桑巴',
            ]
        elif attr in self.timestamp_keys:
            gte_keys = [
                '18560101010000.0Z',
                '19140103010000.0Z',
                '19560101010010.0Z',
                '19700101000000.0Z',
                '19991231211234.3Z',
                '20061111211234.0Z',
                '20390901041234.0Z',
                '25560101010000.0Z',
            ]
        elif attr not in self.numeric_sorted_keys:
            gte_keys = [
                '3',
                'abc',
                ' ',
                '!@#!@#!',
                'kōkako',
                '¹',
                'ŋđ¼³ŧ“«đð',
                '\n\t\t',
                '桑巴',
                'zzzz',
            ]
            if expected_order:
                # a value that sorts just after an existing one
                gte_keys.append(expected_order[len(expected_order) // 2] + ' tail')

        else:
            # "numeric" means positive integers
            # doesn't work with -1, 3.14, ' 3', '9' * 20
            gte_keys = ['3',
                        '1' * 10,
                        '1',
                        '9' * 7,
                        '0']

            if attr in self.int64_keys:
                gte_keys += ['3' * 12, '71' * 8]

        gte_users = []
        for i, x in enumerate(gte_keys):
            user = self.create_user(i, N_ELEMENTS,
                                    prefix='gte',
                                    attrs={attr: x})
            gte_users.append(user)

        gte_order = self.get_expected_order(attr)
        for user in gte_users:
            self.delete_user(user)

        # for sanity's sake
        expected_order_2 = self.get_expected_order(attr, expression=expression)
        self.assertEqual(expected_order, expected_order_2)

        # index to the first one with each value
        index_map = {}
        for i, value in enumerate(expected_order):
            index_map.setdefault(value, i)

        # Map gte tests to indexes in expected order. This will break
        # if gte_order and expected_order are differently ordered (as
        # it should).
        gte_map = {}

        # Values seen since the last one that exists in expected_order;
        # they all resolve to the index of the next existing value.
        pending = []
        for value in gte_order:
            if value in index_map:
                i = index_map[value]
                gte_map[value] = i
                for earlier in pending:
                    gte_map[earlier] = i
                pending = []
            else:
                pending.append(value)

        # Anything left sorts after every expected entry.
        for value in pending:
            gte_map[value] = len(expected_order)

        return gte_order, expected_order, gte_map

    def assertCorrectResults(self, results, expected_order,
                             offset, before, after):
        """A helper to calculate offsets correctly and say as much as possible
        when something goes wrong.

        :param results: the values a VLV search returned
        :param expected_order: the complete sorted list; entries may be
            tuples, in which case only their first element is compared
        :param offset: 1-based position of the target entry
        :param before: number of entries requested before the target
        :param after: number of entries requested after the target
        :raises AssertionError: (via ``assertEqual``) when the expected
            slice differs from ``results``, after printing diagnostics
        """

        start = max(offset - before - 1, 0)
        end = offset + after
        expected_results = expected_order[start:end]

        # if it is a tuple with the cn, drop the cn
        if expected_results and isinstance(expected_results[0], tuple):
            expected_results = [x[0] for x in expected_results]

        if expected_results == results:
            return

        print("expected order: %s" % expected_order[:20])
        if len(expected_order) > 20:
            print("... and %d more not shown" % (len(expected_order) - 20))

        print("offset %d before %d after %d" % (offset, before, after))
        print("start %d end %d" % (start, end))
        print("expected: %s" % expected_results)
        print("got     : %s" % results)
        self.assertEqual(expected_results, results)

    def test_server_vlv_with_cookie(self):
        """Walk VLV windows over every attribute, passing along the cookie
        taken from the previous response's controls."""
        n = len(self.users)
        skipped = ('dn', 'objectclass')
        for attr in [key for key in self.users[0].keys()
                     if key not in skipped]:
            expected_order = self.get_expected_order(attr)
            sort_control = "server_sort:1:0:%s" % attr
            # Only the very first search for an attribute has no cookie.
            previous = None
            for before in (10, 0, 3, 1, 4, 5, 2):
                for after in (0, 3, 1, 4, 5, 2, 7):
                    first_offset = max(1, before - 2)
                    stop_offset = min(n - after + 2, n)
                    for offset in range(first_offset, stop_offset):
                        if previous is None:
                            vlv_search = "vlv:1:%d:%d:%d:0" % (
                                before, after, offset)
                        else:
                            cookie = get_cookie(previous.controls, n)
                            vlv_search = "vlv:1:%d:%d:%d:%s:%s" % (
                                before, after, offset, n, cookie)

                        previous = self.ldb.search(
                            self.ou,
                            scope=ldb.SCOPE_ONELEVEL,
                            attrs=[attr],
                            controls=[sort_control, vlv_search])

                        results = [msg[attr][0] for msg in previous]

                        self.assertCorrectResults(results, expected_order,
                                                  offset, before, after)

    def run_index_tests_with_expressions(self, expressions):
        """Run offset-based VLV searches for every attribute under each
        of the given search expressions.

        The expected order is fetched with the same expression, and the
        window is always symmetric (``before == after``).
        """
        # Here we don't test every before/after combination.
        skipped = ('dn', 'objectclass')
        for attr in [key for key in self.users[0].keys()
                     if key not in skipped]:
            sort_control = "server_sort:1:0:%s" % attr
            for expression in expressions:
                expected_order = self.get_expected_order(attr, expression)
                n = len(expected_order)
                # Only the first search per expression has no cookie.
                previous = None
                for width in range(11):
                    first_offset = max(1, width - 2)
                    stop_offset = min(n - width + 2, n)
                    for offset in range(first_offset, stop_offset):
                        if previous is None:
                            vlv_search = "vlv:1:%d:%d:%d:0" % (
                                width, width, offset)
                        else:
                            cookie = get_cookie(previous.controls)
                            vlv_search = "vlv:1:%d:%d:%d:%s:%s" % (
                                width, width, offset, n, cookie)

                        previous = self.ldb.search(
                            self.ou,
                            expression=expression,
                            scope=ldb.SCOPE_ONELEVEL,
                            attrs=[attr],
                            controls=[sort_control, vlv_search])

                        results = [msg[attr][0] for msg in previous]

                        self.assertCorrectResults(results, expected_order,
                                                  offset, width, width)

    def test_server_vlv_with_expression(self):
        """What happens when we run the VLV with an expression?"""
        expressions = ["(objectClass=*)",
                       "(cn=%s)" % self.users[-1]['cn'],
                       "(roomNumber=%s)" % self.users[0]['roomNumber'],
                       ]
        self.run_index_tests_with_expressions(expressions)

    def test_server_vlv_with_failing_expression(self):
        """What happens when we run the VLV on an expression that matches
        nothing?"""
        expressions = ["(samaccountname=testferf)",
                       "(cn=hefalump)",
                       ]
        self.run_index_tests_with_expressions(expressions)

    def run_gte_tests_with_expressions(self, expressions):
        """Run greater-than-or-equal VLV searches for every attribute
        under each of the given search expressions.

        The probe values and their expected landing indexes come from
        ``get_gte_tests_and_order``; each search after the first passes
        along the cookie from the previous response.
        """
        # Here we don't test every before/after combination.
        attrs = [x for x in self.users[0].keys() if x not in
                 ('dn', 'objectclass')]
        for expression in expressions:
            for attr in attrs:
                # expected_order is used as returned: the helper already
                # re-reads it after removing its temporary users and
                # asserts that it has not changed.
                gte_order, expected_order, gte_map = \
                    self.get_gte_tests_and_order(attr, expression)
                # In case there is some order dependency, disorder tests
                gte_tests = gte_order[:]
                random.seed(2)
                random.shuffle(gte_tests)
                res = None
                sort_control = "server_sort:1:0:%s" % attr

                for before in range(0, 11):
                    after = before
                    for gte in gte_tests:
                        if res is not None:
                            cookie = get_cookie(res.controls)
                        else:
                            cookie = None
                        vlv_search = encode_vlv_control(before=before,
                                                        after=after,
                                                        gte=gte,
                                                        cookie=cookie)

                        res = self.ldb.search(self.ou,
                                              scope=ldb.SCOPE_ONELEVEL,
                                              expression=expression,
                                              attrs=[attr],
                                              controls=[sort_control,
                                                        vlv_search])

                        results = [x[attr][0] for x in res]
                        offset = gte_map.get(gte, len(expected_order))

                        # here offset is 0-based
                        start = max(offset - before, 0)
                        end = offset + 1 + after

                        expected_results = expected_order[start: end]

                        self.assertEqual(expected_results, results)

    def test_vlv_gte_with_expression(self):
        """What happens when we run the VLV with an expression?"""
        expressions = ["(objectClass=*)",
                       "(cn=%s)" % self.users[-1]['cn'],
                       "(roomNumber=%s)" % self.users[0]['roomNumber'],
                       ]
        self.run_gte_tests_with_expressions(expressions)

    def test_vlv_gte_with_failing_expression(self):
        """What happens when we run the VLV on an expression that matches
        nothing?"""
        expressions = ["(samaccountname=testferf)",
                       "(cn=hefalump)",
                       ]
        self.run_gte_tests_with_expressions(expressions)

    def test_server_vlv_with_cookie_while_adding_and_deleting(self):
        """What happens if we add or remove items in the middle of the VLV?

        Nothing. The search and the sort is not repeated, and we only
        deal with the objects originally found.
        """
        attrs = ['cn'] + [x for x in self.users[0].keys() if x not in
                          ('dn', 'objectclass')]
        user_number = 0
        # list() keeps the concatenation valid where range() does not
        # return a list (Python 3).
        window_sizes = list(range(0, 3)) + [6, 11, 19]
        for attr in attrs:
            full_results, controls, sort_control = \
                self.get_full_list(attr, True)
            original_n = len(self.users)

            # (value, cn) tuples; entries are replaced by None below once
            # the corresponding user has been deleted.
            expected_order = full_results
            random.seed(1)

            for before in window_sizes:
                for after in window_sizes:
                    start = max(before - 1, 1)
                    end = max(start + 4, original_n - after + 2)
                    for offset in range(start, end):
                        cookie = get_cookie(controls, original_n)
                        vlv_search = encode_vlv_control(before=before,
                                                        after=after,
                                                        offset=offset,
                                                        n=original_n,
                                                        cookie=cookie)

                        res = self.ldb.search(self.ou,
                                              scope=ldb.SCOPE_ONELEVEL,
                                              attrs=[attr],
                                              controls=[sort_control,
                                                        vlv_search])

                        controls = res.controls
                        results = [x[attr][0] for x in res]

                        # Clamp the window to the original list.
                        real_offset = max(1, min(offset, len(expected_order)))
                        begin_offset = max(real_offset - before - 1, 0)
                        real_before = min(before, real_offset - 1)
                        real_after = min(after,
                                         len(expected_order) - real_offset)
                        wanted = real_before + real_after + 1

                        # Collect the window, stepping over deleted
                        # (None) entries.
                        expected_results = []
                        for x in expected_order[begin_offset:]:
                            if x is not None:
                                expected_results.append(x[0])
                                if len(expected_results) == wanted:
                                    break

                        if expected_results != results:
                            print("attr %s before %d after %d offset %d" %
                                  (attr, before, after, offset))
                        self.assertEqual(expected_results, results)

                        # Randomly add and/or delete a user before the
                        # next search. `n` is deliberately read once, so
                        # a user added here cannot be picked for deletion
                        # in the same round.
                        n = len(self.users)
                        if random.random() < 0.1 + (n < 5) * 0.05:
                            if n == 0:
                                i = 0
                            else:
                                i = random.randrange(n)
                            self.create_user(i, n, suffix='-%s' %
                                             user_number)
                            user_number += 1
                        if random.random() < 0.1 + (n > 50) * 0.02 and n:
                            index = random.randrange(n)
                            user = self.users.pop(index)

                            self.ldb.delete(user['dn'])

                            replaced = (user[attr], user['cn'])
                            if replaced in expected_order:
                                i = expected_order.index(replaced)
                                expected_order[i] = None

    def test_server_vlv_with_cookie_while_changing(self):
        """What happens if we modify items in the middle of the VLV?

        The expected behaviour (as found on Windows) is the sort is
        not repeated, but the changes in attributes are reflected.
        """
        attrs = [x for x in self.users[0].keys() if x not in
                 ('dn', 'objectclass', 'cn')]
        for attr in attrs:
            n_users = len(self.users)
            sort_control = "server_sort:1:0:%s" % attr
            i = 0

            # First we'll fetch the whole list so we know the original
            # sort order. This is necessary because we don't know how
            # the server will order equivalent items. We are using the
            # dn as a key.
            half_n = n_users // 2
            vlv_search = "vlv:1:%d:%d:%d:0" % (half_n, half_n, half_n + 1)
            res = self.ldb.search(self.ou,
                                  scope=ldb.SCOPE_ONELEVEL,
                                  attrs=['dn', attr],
                                  controls=[sort_control, vlv_search])

            dn_order = [str(x['dn']) for x in res]
            # Upper-cased value per position; updated below whenever an
            # entry is modified on the server.
            values = [x[attr][0].upper() for x in res]

            for before in range(0, 3):
                for after in range(0, 3):
                    for offset in range(1 + before, n_users - after):
                        cookie = get_cookie(res.controls, len(self.users))
                        vlv_search = ("vlv:1:%d:%d:%d:%s:%s" %
                                      (before, after, offset, len(self.users),
                                       cookie))

                        res = self.ldb.search(self.ou,
                                              scope=ldb.SCOPE_ONELEVEL,
                                              attrs=['dn', attr],
                                              controls=[sort_control,
                                                        vlv_search])

                        # The entries must come back in the original
                        # order...
                        dn_results = [str(x['dn']) for x in res]
                        dn_expected = dn_order[offset - before - 1:
                                               offset + after]

                        self.assertEqual(dn_expected, dn_results)

                        # ...but carrying their current values.
                        results = [x[attr][0].upper() for x in res]

                        self.assertCorrectResults(results, values,
                                                  offset, before, after)

                        # On every third search, copy the value at one
                        # position over the value at another.
                        i += 1
                        if i % 3 == 2:
                            if (attr in self.locale_sorted_keys or
                                    attr in self.binary_sorted_keys):
                                i1 = i % n_users
                                i2 = (i ^ 255) % n_users
                                dn1 = dn_order[i1]
                                v2 = values[i2]

                                # NOTE(review): this tests a *value*
                                # against the list of attribute *names*,
                                # so it looks like it is never true;
                                # `attr in self.locale_sorted_keys` may
                                # have been intended. Left as is because
                                # changing it alters what is written to
                                # the server -- confirm before fixing.
                                if v2 in self.locale_sorted_keys:
                                    v2 += '-%d' % i

                                values[i1] = v2

                                m = ldb.Message()
                                m.dn = ldb.Dn(self.ldb, dn1)
                                m[attr] = ldb.MessageElement(v2,
                                                             ldb.FLAG_MOD_REPLACE,
                                                             attr)

                                self.ldb.modify(m)

    def test_server_vlv_fractions_with_cookie(self):
        """What happens when the count is set to an arbitrary number?

        In that case the offset and the count form a fraction, and the
        VLV should be centred at a point offset/count of the way
        through. For example, if offset is 3 and count is 6, the VLV
        should be looking around halfway. The actual algorithm is a
        bit fiddlier than that, because of the one-basedness of VLV.
        """
        attrs = [x for x in self.users[0].keys() if x not in
                 ('dn', 'objectclass')]

        n_users = len(self.users)

        random.seed(4)

        for attr in attrs:
            full_results, controls, sort_control = self.get_full_list(attr)
            self.assertEqual(len(full_results), n_users)
            for before in range(0, 2):
                for after in range(0, 2):
                    # Both loops start at 1: offset zero is illegal
                    # unless content count is zero, so neither a zero
                    # denominator nor a zero offset is ever sent.
                    for denominator in range(1, 20):
                        for offset in range(1, denominator + 3):
                            cookie = get_cookie(controls, len(self.users))
                            vlv_search = ("vlv:1:%d:%d:%d:%s:%s" %
                                          (before, after, offset,
                                           denominator,
                                           cookie))
                            res = self.ldb.search(self.ou,
                                                  scope=ldb.SCOPE_ONELEVEL,
                                                  attrs=[attr],
                                                  controls=[sort_control,
                                                            vlv_search])

                            results = [x[attr][0].lower() for x in res]

                            if denominator == 1:
                                # the offset can only be 1, but the 1/1 case
                                # means something special
                                if offset == 1:
                                    real_offset = n_users
                                else:
                                    real_offset = 1
                            else:
                                # An offset past the count is treated as
                                # the count itself.
                                clamped = min(offset, denominator)
                                real_offset = (1 +
                                               int(round((n_users - 1) *
                                                         (clamped - 1) /
                                                         (denominator - 1.0))))

                            self.assertCorrectResults(results, full_results,
                                                      real_offset, before,
                                                      after)

                            controls = res.controls

    def test_server_vlv_no_cookie(self):
        """Check VLV windows when every search is sent without a cookie."""
        n_users = len(self.users)
        skipped = ('dn', 'objectclass')
        for attr in [key for key in self.users[0].keys()
                     if key not in skipped]:
            expected_order = self.get_expected_order(attr)
            sort_control = "server_sort:1:0:%s" % attr
            for before in range(5):
                for after in range(7):
                    for offset in range(1 + before, n_users - after):
                        vlv_search = "vlv:1:%d:%d:%d:0" % (before, after,
                                                           offset)
                        res = self.ldb.search(self.ou,
                                              scope=ldb.SCOPE_ONELEVEL,
                                              attrs=[attr],
                                              controls=[sort_control,
                                                        vlv_search])
                        results = [msg[attr][0] for msg in res]
                        self.assertCorrectResults(results, expected_order,
                                                  offset, before, after)

    def get_expected_order_showing_deleted(self, attr,
                                           expression="(|(cn=vlvtest*)(cn=vlv-deleted*))",
                                           base=None,
                                           scope=ldb.SCOPE_SUBTREE
                                           ):
        """Fetch the whole list sorted on the attribute, using sort only,
        searching in the entire tree, not just our OU. This is the
        way to find deleted objects.

        :param attr: attribute to sort on and fetch
        :param expression: search filter
        :param base: search base; ``self.base_dn`` when None
        :param scope: search scope
        :return: list with the first ``attr`` value of each result
        """
        search_base = self.base_dn if base is None else base
        res = self.ldb.search(search_base,
                              scope=scope,
                              expression=expression,
                              attrs=[attr],
                              controls=["server_sort:1:0:%s" % attr,
                                        "show_deleted:1"])
        return [msg[attr][0] for msg in res]

    def add_deleted_users(self, n):
        """Create ``n`` users with the ``vlv-deleted`` prefix, then delete
        them all.

        Every user is created before the first one is deleted.
        """
        doomed = []
        for index in range(n):
            doomed.append(self.create_user(index, n, prefix='vlv-deleted'))

        for victim in doomed:
            self.delete_user(victim)

    def test_server_vlv_no_cookie_show_deleted(self):
        """What do we see with the show_deleted control?"""
        # add some deleted users first, just in case there are none
        self.add_deleted_users(6)
        random.seed(22)
        expression = "(|(cn=vlvtest*)(cn=vlv-deleted*))"
        show_deleted_control = "show_deleted:1"

        for attr in ('objectGUID',
                     'cn',
                     'sAMAccountName',
                     'objectSid',
                     'name',
                     'whenChanged',
                     'usnChanged'):
            expected_order = self.get_expected_order_showing_deleted(
                attr, expression)
            n = len(expected_order)
            sort_control = "server_sort:1:0:%s" % attr
            for before in (3, 1, 0):
                for after in (0, 2):
                    lowest = max(1, before - 2)
                    limit = min(n - after + 2, n)
                    # don't test every position, because there could be
                    # hundreds. jump back and forth instead
                    for _ in range(20):
                        offset = random.randrange(lowest, limit)
                        vlv_search = "vlv:1:%d:%d:%d:0" % (before, after,
                                                           offset)
                        res = self.ldb.search(self.base_dn,
                                              expression=expression,
                                              scope=ldb.SCOPE_SUBTREE,
                                              attrs=[attr],
                                              controls=[sort_control,
                                                        show_deleted_control,
                                                        vlv_search])
                        results = [msg[attr][0] for msg in res]
                        self.assertCorrectResults(results, expected_order,
                                                  offset, before, after)

    def test_server_vlv_no_cookie_show_deleted_only(self):
        """Check cookie-less VLV windows over deleted objects only.

        The searches run one level below 'CN=Deleted Objects,<base_dn>' with
        the filter (cn=vlv-deleted*) and the show_deleted control, so no
        non-deleted objects are looked at.
        """
        sort_attrs = ['objectGUID',
                      'cn',
                      'sAMAccountName',
                      'objectSid',
                      'whenChanged']

        # add some deleted users first, just in case there are none
        self.add_deleted_users(4)
        base = 'CN=Deleted Objects,%s' % self.base_dn
        expression = "(cn=vlv-deleted*)"
        show_deleted = "show_deleted:1"

        for attr in sort_attrs:
            expected_order = self.get_expected_order_showing_deleted(
                attr,
                expression=expression,
                base=base,
                scope=ldb.SCOPE_ONELEVEL)
            count = len(expected_order)
            print("searching for attr %s amongst %d deleted objects" %
                  (attr, count))
            sort_control = "server_sort:1:0:%s" % attr
            # sample roughly ten offsets across the list, never less than
            # one position apart
            step = max(count // 10, 1)
            for before in (3, 0):
                for after in (0, 2):
                    for offset in range(1 + before, count - after, step):
                        vlv_control = "vlv:1:%d:%d:%d:0" % (before, after,
                                                            offset)
                        res = self.ldb.search(
                            base,
                            expression=expression,
                            scope=ldb.SCOPE_ONELEVEL,
                            attrs=[attr],
                            controls=[sort_control,
                                      show_deleted,
                                      vlv_control])
                        results = [entry[attr][0] for entry in res]
                        self.assertCorrectResults(results, expected_order,
                                                  offset, before, after)



    def test_server_vlv_with_cookie_show_deleted(self):
        """Check cookie-carrying VLV windows while deleted objects are shown.

        For each sort attribute the first VLV search is sent without a
        cookie; every later search re-uses the cookie extracted from the
        previous response.  Each window is compared, via
        assertCorrectResults(), with the ordering returned by
        get_expected_order_showing_deleted().
        """
        attrs = ['objectGUID',
                 'cn',
                 'sAMAccountName',
                 'objectSid',
                 'name',
                 'whenChanged',
                 'usnChanged'
        ]
        # these do not depend on the attribute, so build them once
        show_deleted_control = "show_deleted:1"
        expression = "(|(cn=vlvtest*)(cn=vlv-deleted*))"

        self.add_deleted_users(6)
        # fixed seed: the sequence of offsets is the same on every run
        random.seed(23)
        for attr in attrs:
            sort_control = "server_sort:1:0:%s" % attr
            # each attribute starts a fresh VLV conversation (no cookie yet)
            res = None
            expected_order = self.get_expected_order_showing_deleted(attr)
            n = len(expected_order)
            for before in [3, 2, 1, 0]:
                after = before
                for _ in range(20):
                    offset = random.randrange(max(1, before - 2),
                                              min(n - after + 2, n))
                    if res is None:
                        vlv_search = "vlv:1:%d:%d:%d:0" % (before, after,
                                                           offset)
                    else:
                        cookie = get_cookie(res.controls, n)
                        vlv_search = ("vlv:1:%d:%d:%d:%s:%s" %
                                      (before, after, offset, n,
                                       cookie))

                    res = self.ldb.search(self.base_dn,
                                          expression=expression,
                                          scope=ldb.SCOPE_SUBTREE,
                                          attrs=[attr],
                                          controls=[sort_control,
                                                    vlv_search,
                                                    show_deleted_control])

                    results = [x[attr][0] for x in res]

                    self.assertCorrectResults(results, expected_order,
                                              offset, before, after)


    def test_server_vlv_gte_with_cookie(self):
        """Check greater-than-or-equal VLV searches that carry a cookie.

        For each user attribute (except 'dn' and 'objectclass') the gte
        values from get_gte_tests_and_order() are shuffled and searched for
        with several before/after counts.  After the first search of an
        attribute, the cookie from the previous response is passed along.
        The returned window must equal the matching slice of the expected
        order.
        """
        attrs = [x for x in self.users[0].keys() if x not in
                 ('dn', 'objectclass')]
        for attr in attrs:
            gte_order, expected_order, gte_map = \
                                        self.get_gte_tests_and_order(attr)
            # In case there is some order dependency, disorder tests
            gte_tests = gte_order[:]
            random.seed(1)
            random.shuffle(gte_tests)
            res = None
            sort_control = "server_sort:1:0:%s" % attr
            for before in [0, 1, 2, 4]:
                for after in [0, 1, 3, 6]:
                    for gte in gte_tests:
                        if res is not None:
                            cookie = get_cookie(res.controls, len(self.users))
                        else:
                            cookie = None
                        vlv_search = encode_vlv_control(before=before,
                                                        after=after,
                                                        gte=gte,
                                                        cookie=cookie)

                        res = self.ldb.search(self.ou,
                                              scope=ldb.SCOPE_ONELEVEL,
                                              attrs=[attr],
                                              controls=[sort_control,
                                                        vlv_search])

                        results = [x[attr][0] for x in res]
                        # a gte value missing from the map is treated as
                        # pointing just past the end of the list
                        offset = gte_map.get(gte, len(expected_order))

                        # here offset is 0-based
                        start = max(offset - before, 0)
                        end = offset + 1 + after

                        expected_results = expected_order[start: end]

                        # assertEqual, not the deprecated assertEquals alias
                        self.assertEqual(expected_results, results)

    def test_server_vlv_gte_no_cookie(self):
        """Check greater-than-or-equal VLV searches sent without a cookie.

        For each user attribute (except 'dn' and 'objectclass') the gte
        values from get_gte_tests_and_order() are shuffled and searched for
        with several before/after counts.  The returned window must equal
        the matching slice of the expected order; on a mismatch some
        diagnostics are printed before the assertion fails.
        """
        attrs = [x for x in self.users[0].keys() if x not in
                 ('dn', 'objectclass')]
        for attr in attrs:
            gte_order, expected_order, gte_map = \
                                        self.get_gte_tests_and_order(attr)
            # In case there is some order dependency, disorder tests
            gte_tests = gte_order[:]
            random.seed(1)
            random.shuffle(gte_tests)

            sort_control = "server_sort:1:0:%s" % attr
            for before in [0, 1, 3]:
                for after in [0, 4]:
                    for gte in gte_tests:
                        vlv_search = encode_vlv_control(before=before,
                                                        after=after,
                                                        gte=gte)

                        res = self.ldb.search(self.ou,
                                              scope=ldb.SCOPE_ONELEVEL,
                                              attrs=[attr],
                                              controls=[sort_control,
                                                        vlv_search])
                        results = [x[attr][0] for x in res]

                        # here offset is 0-based; a gte value missing from
                        # the map is treated as pointing past the end
                        offset = gte_map.get(gte, len(expected_order))
                        start = max(offset - before, 0)
                        end = offset + after + 1
                        expected_results = expected_order[start: end]
                        if expected_results != results:
                            # dump some context before the assertion below
                            # reports the failure
                            middle = expected_order[len(expected_order) // 2]
                            print(expected_results, results)
                            print(middle)
                            print(expected_order)
                            print()
                            print("\nattr %s offset %d before %d "
                                  "after %d gte %s" %
                                  (attr, offset, before, after, gte))
                        # assertEqual, not the deprecated assertEquals alias
                        self.assertEqual(expected_results, results)

    def test_multiple_searches(self):
        """Check that old VLV cookies expire and cookies are per-connection.

        One VLV search is opened per entry of the attribute list built below
        (at most 12 entries) and the cookie of every response is kept.
        Afterwards the oldest cookie is expected to be rejected, the newest
        one to still work on this connection, and the newest one to be
        rejected on a brand new connection unless the VLV control is sent
        as non-critical.
        """
        # Windows has a limit of 10 VLVs where there are low numbers
        # of objects in each search.
        usable = [key for key in self.users[0].keys()
                  if key not in ('dn', 'objectclass')]
        attrs = (usable * 2)[:12]

        vlv_cookies = []
        for attr in attrs:
            sort_control = "server_sort:1:0:%s" % attr

            res = self.ldb.search(self.ou,
                                  scope=ldb.SCOPE_ONELEVEL,
                                  attrs=[attr],
                                  controls=[sort_control,
                                            "vlv:1:1:1:1:0"])

            vlv_cookies.append(get_cookie(res.controls, len(self.users)))
            time.sleep(0.2)

        # Everything below re-uses `attr` and `sort_control` as left behind
        # by the final loop iteration.

        # now this one should fail
        with self.assertRaises(ldb.LdbError):
            self.ldb.search(self.ou,
                            scope=ldb.SCOPE_ONELEVEL,
                            attrs=[attr],
                            controls=[sort_control,
                                      "vlv:1:1:1:1:0:%s" % vlv_cookies[0]])

        # and this one should succeed
        self.ldb.search(self.ou,
                        scope=ldb.SCOPE_ONELEVEL,
                        attrs=[attr],
                        controls=[sort_control,
                                  "vlv:1:1:1:1:0:%s" % vlv_cookies[-1]])

        # this one should fail because it is a new connection and
        # doesn't share cookies
        new_ldb = SamDB(host, credentials=creds,
                        session_info=system_session(lp), lp=lp)

        with self.assertRaises(ldb.LdbError):
            new_ldb.search(self.ou,
                           scope=ldb.SCOPE_ONELEVEL,
                           attrs=[attr],
                           controls=[sort_control,
                                     "vlv:1:1:1:1:0:%s" % vlv_cookies[-1]])

        # but now without the critical flag it just does no VLV.
        new_ldb.search(self.ou,
                       scope=ldb.SCOPE_ONELEVEL,
                       attrs=[attr],
                       controls=[sort_control,
                                 "vlv:0:1:1:1:0:%s" % vlv_cookies[-1]])
Exemple #17
0
class DsdbTests(TestCase):
    def setUp(self):
        """Open a SamDB connection and create a throwaway test user.

        The user's DN is stored in self.account_dn and its removal is
        registered as a cleanup.
        """
        super(DsdbTests, self).setUp()
        self.lp = samba.tests.env_loadparm()
        self.creds = Credentials()
        self.creds.guess(self.lp)
        self.session = system_session()
        self.samdb = SamDB(session_info=self.session,
                           credentials=self.creds,
                           lp=self.lp)

        # Create a test user; six hex digits of a fresh UUID keep the name
        # distinct between runs
        name_suffix = uuid.uuid4().hex[:6]
        user_name = "dsdb-user-" + name_suffix
        user_pass = samba.generate_random_password(32, 32)
        user_description = "Test user for dsdb test"

        base_dn = self.samdb.domain_dn()

        self.account_dn = "".join(("cn=", user_name, ",cn=Users,", base_dn))
        self.samdb.newuser(username=user_name,
                           password=user_pass,
                           description=user_description)
        # Cleanup (teardown)
        self.addCleanup(delete_force, self.samdb, self.account_dn)

    def test_get_oid_from_attrid(self):
        """get_oid_from_attid(591614) must return 1.2.840.113556.1.4.1790."""
        oid = self.samdb.get_oid_from_attid(591614)
        # assertEqual, not the deprecated assertEquals alias
        self.assertEqual(oid, "1.2.840.113556.1.4.1790")

    def test_error_replpropertymetadata(self):
        """Bumping only the description version must be rejected.

        The test user's replPropertyMetaData blob is unpacked, the version
        of the entry with attid 13 (description) is incremented, and the
        repacked blob is written back with the local control
        1.3.6.1.4.1.7165.4.3.14.  The modify is expected to raise
        ldb.LdbError.
        """
        res = self.samdb.search(scope=ldb.SCOPE_SUBTREE,
                                base=self.account_dn,
                                attrs=["replPropertyMetaData"])
        repl = ndr_unpack(drsblobs.replPropertyMetaDataBlob,
                          res[0]["replPropertyMetaData"][0])
        ctr = repl.ctr
        for o in ctr.array:
            # Search for Description
            if o.attid == 13:
                o.version = o.version + 1
        replBlob = ndr_pack(repl)
        msg = ldb.Message()
        msg.dn = res[0].dn
        msg["replPropertyMetaData"] = ldb.MessageElement(
            replBlob, ldb.FLAG_MOD_REPLACE, "replPropertyMetaData")
        self.assertRaises(ldb.LdbError, self.samdb.modify, msg,
                          ["local_oid:1.3.6.1.4.1.7165.4.3.14:0"])

    def test_error_replpropertymetadata_nochange(self):
        """Writing back an unchanged replPropertyMetaData blob must fail.

        The blob is unpacked and repacked without modification, then sent as
        a replace with the local control 1.3.6.1.4.1.7165.4.3.14; the modify
        is expected to raise ldb.LdbError.
        """
        found = self.samdb.search(base=self.account_dn,
                                  scope=ldb.SCOPE_SUBTREE,
                                  attrs=["replPropertyMetaData"])
        entry = found[0]
        metadata = ndr_unpack(drsblobs.replPropertyMetaDataBlob,
                              entry["replPropertyMetaData"][0])
        packed = ndr_pack(metadata)

        change = ldb.Message()
        change.dn = entry.dn
        change["replPropertyMetaData"] = ldb.MessageElement(
            packed, ldb.FLAG_MOD_REPLACE, "replPropertyMetaData")

        with self.assertRaises(ldb.LdbError):
            self.samdb.modify(change,
                              ["local_oid:1.3.6.1.4.1.7165.4.3.14:0"])

    def test_error_replpropertymetadata_allow_sort(self):
        """An unchanged blob is accepted when both local controls are sent.

        The blob is unpacked and repacked without modification and written
        back with the local controls 1.3.6.1.4.1.7165.4.3.14 and
        1.3.6.1.4.1.7165.4.3.25; the test passes if the modify does not
        raise.
        """
        found = self.samdb.search(base=self.account_dn,
                                  scope=ldb.SCOPE_SUBTREE,
                                  attrs=["replPropertyMetaData"])
        entry = found[0]
        metadata = ndr_unpack(drsblobs.replPropertyMetaDataBlob,
                              entry["replPropertyMetaData"][0])
        packed = ndr_pack(metadata)

        change = ldb.Message()
        change.dn = entry.dn
        change["replPropertyMetaData"] = ldb.MessageElement(
            packed, ldb.FLAG_MOD_REPLACE, "replPropertyMetaData")

        controls = ["local_oid:1.3.6.1.4.1.7165.4.3.14:0",
                    "local_oid:1.3.6.1.4.1.7165.4.3.25:0"]
        self.samdb.modify(change, controls)

    def test_twoatt_replpropertymetadata(self):
        """Replacing the metadata blob together with description must fail.

        The description entry (attid 13) of the blob gets its version
        incremented and its local_usn set to uSNChanged + 1.  The message
        replaces both replPropertyMetaData and description, using the local
        control 1.3.6.1.4.1.7165.4.3.14, and the modify is expected to raise
        ldb.LdbError.
        """
        res = self.samdb.search(scope=ldb.SCOPE_SUBTREE,
                                base=self.account_dn,
                                attrs=["replPropertyMetaData", "uSNChanged"])
        repl = ndr_unpack(drsblobs.replPropertyMetaDataBlob,
                          res[0]["replPropertyMetaData"][0])
        ctr = repl.ctr
        for o in ctr.array:
            # Search for Description
            if o.attid == 13:
                o.version = o.version + 1
                o.local_usn = int(str(res[0]["uSNChanged"])) + 1
        replBlob = ndr_pack(repl)
        msg = ldb.Message()
        msg.dn = res[0].dn
        msg["replPropertyMetaData"] = ldb.MessageElement(
            replBlob, ldb.FLAG_MOD_REPLACE, "replPropertyMetaData")
        msg["description"] = ldb.MessageElement("new val",
                                                ldb.FLAG_MOD_REPLACE,
                                                "description")
        self.assertRaises(ldb.LdbError, self.samdb.modify, msg,
                          ["local_oid:1.3.6.1.4.1.7165.4.3.14:0"])

    def test_set_replpropertymetadata(self):
        """A consistently updated replPropertyMetaData blob is accepted.

        The description entry (attid 13) gets its version incremented and
        both local_usn and originating_usn set to uSNChanged + 1.  The blob
        is written back with the local control 1.3.6.1.4.1.7165.4.3.14; the
        test passes if the modify does not raise.
        """
        res = self.samdb.search(scope=ldb.SCOPE_SUBTREE,
                                base=self.account_dn,
                                attrs=["replPropertyMetaData", "uSNChanged"])
        repl = ndr_unpack(drsblobs.replPropertyMetaDataBlob,
                          res[0]["replPropertyMetaData"][0])
        ctr = repl.ctr
        for o in ctr.array:
            # Search for Description
            if o.attid == 13:
                # both USN fields get the same value, so compute it once
                next_usn = int(str(res[0]["uSNChanged"])) + 1
                o.version = o.version + 1
                o.local_usn = next_usn
                o.originating_usn = next_usn
        replBlob = ndr_pack(repl)
        msg = ldb.Message()
        msg.dn = res[0].dn
        msg["replPropertyMetaData"] = ldb.MessageElement(
            replBlob, ldb.FLAG_MOD_REPLACE, "replPropertyMetaData")
        self.samdb.modify(msg, ["local_oid:1.3.6.1.4.1.7165.4.3.14:0"])

    def test_ok_get_attribute_from_attid(self):
        """get_attribute_from_attid(13) must return "description"."""
        # assertEqual, not the deprecated assertEquals alias
        self.assertEqual(self.samdb.get_attribute_from_attid(13),
                         "description")

    def test_ko_get_attribute_from_attid(self):
        """get_attribute_from_attid(11979) must return None."""
        # assertIsNone replaces the deprecated assertEquals(..., None)
        self.assertIsNone(self.samdb.get_attribute_from_attid(11979))

    def test_get_attribute_replmetadata_version(self):
        """The test user's unicodePwd metadata version is expected to be 2."""
        res = self.samdb.search(scope=ldb.SCOPE_SUBTREE,
                                base=self.account_dn,
                                attrs=["dn"])
        # assertEqual, not the deprecated assertEquals alias
        self.assertEqual(len(res), 1)
        dn = str(res[0].dn)
        self.assertEqual(
            self.samdb.get_attribute_replmetadata_version(dn, "unicodePwd"), 2)

    def test_set_attribute_replmetadata_version(self):
        """Setting the description metadata version must be readable back.

        The current version is read, set to two more than that, and then
        read again to confirm that the new value was stored.
        """
        res = self.samdb.search(scope=ldb.SCOPE_SUBTREE,
                                base=self.account_dn,
                                attrs=["dn"])
        # assertEqual, not the deprecated assertEquals alias
        self.assertEqual(len(res), 1)
        dn = str(res[0].dn)
        version = self.samdb.get_attribute_replmetadata_version(
            dn, "description")
        self.samdb.set_attribute_replmetadata_version(dn, "description",
                                                      version + 2)
        self.assertEqual(
            self.samdb.get_attribute_replmetadata_version(dn, "description"),
            version + 2)

    def test_no_error_on_invalid_control(self):
        """A search carrying the invalid control flagged ':0' must not raise.

        The control is DSDB_CONTROL_INVALID_NOT_IMPLEMENTED; any
        ldb.LdbError from the search fails the test.
        """
        try:
            self.samdb.search(
                scope=ldb.SCOPE_SUBTREE,
                base=self.account_dn,
                attrs=["replPropertyMetaData"],
                controls=[
                    "local_oid:%s:0" %
                    dsdb.DSDB_CONTROL_INVALID_NOT_IMPLEMENTED
                ])
        except ldb.LdbError as e:
            # include the error so a failure can be diagnosed from the log
            self.fail("Should have not raised an exception: %s" % e)

    def test_error_on_invalid_critical_control(self):
        """An error from the invalid control flagged ':1' must be the right one.

        The control is DSDB_CONTROL_INVALID_NOT_IMPLEMENTED.  If the search
        raises ldb.LdbError, its code must be
        ldb.ERR_UNSUPPORTED_CRITICAL_EXTENSION.
        """
        # NOTE(review): a search that raises nothing currently passes this
        # test - confirm whether that is intended.
        try:
            self.samdb.search(
                scope=ldb.SCOPE_SUBTREE,
                base=self.account_dn,
                attrs=["replPropertyMetaData"],
                controls=[
                    "local_oid:%s:1" %
                    dsdb.DSDB_CONTROL_INVALID_NOT_IMPLEMENTED
                ])
        except ldb.LdbError as e:
            (errno, estr) = e.args
            if errno != ldb.ERR_UNSUPPORTED_CRITICAL_EXTENSION:
                # use the unpacked message: exceptions cannot be indexed
                # (e[1]) on Python 3
                self.fail(
                    "Got %s should have got ERR_UNSUPPORTED_CRITICAL_EXTENSION"
                    % estr)

    def allocate_rid(self):
        """Allocate a unique RID for use in the objectSID tests.

        The allocation runs inside its own transaction, which is cancelled
        if the allocation raises and committed otherwise.  The RID is
        returned as a string.
        """
        db = self.samdb
        db.transaction_start()
        try:
            new_rid = db.allocate_rid()
        except BaseException:
            # roll back on any failure, then let the error propagate
            db.transaction_cancel()
            raise
        else:
            db.transaction_commit()
        return str(new_rid)

    def test_duplicate_objectSIDs_allowed_on_foreign_security_principals(self):
        """Ensure duplicate objectSIDs are permitted for foreign principals.

        Without the provision control, adding a foreignSecurityPrincipal is
        expected to fail, both without and with an explicit objectSid.  With
        the provision control the object is added, deleted, and then added
        a second time, which must succeed.
        """

        #
        # We need to build a foreign security principal SID
        # i.e a  SID not in the current domain.
        #
        dom_sid = self.samdb.get_domain_sid()
        # change the last character of the domain SID so that the result
        # differs from the local domain SID
        if str(dom_sid).endswith("0"):
            c = "9"
        else:
            c = "0"
        sid_str = str(dom_sid)[:-1] + c + "-1000"
        sid = ndr_pack(security.dom_sid(sid_str))
        basedn = self.samdb.get_default_basedn()
        dn = "CN=%s,CN=ForeignSecurityPrincipals,%s" % (sid_str, basedn)

        # make sure the object does not outlive the test, even when one of
        # the assertions below fails half way through
        self.addCleanup(delete_force, self.samdb, dn)

        #
        # First without control
        #

        try:
            self.samdb.add({
                "dn": dn,
                "objectClass": "foreignSecurityPrincipal"
            })
        except ldb.LdbError as e:
            (code, msg) = e.args
            self.assertEqual(code, ldb.ERR_OBJECT_CLASS_VIOLATION, str(e))
            werr = "%08X" % werror.WERR_DS_MISSING_REQUIRED_ATT
            self.assertIn(werr, msg)
        else:
            self.fail("No exception should get ERR_OBJECT_CLASS_VIOLATION")

        try:
            self.samdb.add({
                "dn": dn,
                "objectClass": "foreignSecurityPrincipal",
                "objectSid": sid
            })
        except ldb.LdbError as e:
            (code, msg) = e.args
            self.assertEqual(code, ldb.ERR_UNWILLING_TO_PERFORM, str(e))
            werr = "%08X" % werror.WERR_DS_ILLEGAL_MOD_OPERATION
            self.assertIn(werr, msg)
        else:
            self.fail("No exception should get ERR_UNWILLING_TO_PERFORM")

        #
        # We need to use the provision control
        # in order to add foreignSecurityPrincipal
        # objects
        #

        controls = ["provision:0"]
        self.samdb.add({
            "dn": dn,
            "objectClass": "foreignSecurityPrincipal"
        },
                       controls=controls)

        self.samdb.delete(dn)

        # adding the same object again after the delete must succeed
        try:
            self.samdb.add(
                {
                    "dn": dn,
                    "objectClass": "foreignSecurityPrincipal"
                },
                controls=controls)
        except ldb.LdbError as e:
            (code, msg) = e.args
            self.fail("Got unexpected exception %d - %s " % (code, msg))

        # cleanup
        self.samdb.delete(dn)

    def _test_foreignSecurityPrincipal(self, obj_class, fpo_attr):
        """Exercise adding <SID=...> values to ``fpo_attr`` of a new object.

        An object of class ``obj_class`` is created and three SID references
        are added to ``fpo_attr`` in turn:

        * a SID below the local domain SID - expected to fail with
          ERR_UNWILLING_TO_PERFORM / WERR_DS_INVALID_GROUP_TYPE;
        * a SID below S-1-5-32 - expected to fail with
          ERR_NO_SUCH_OBJECT / WERR_NO_SUCH_MEMBER;
        * the SID S-1-5-4294967294 - expected to succeed, after which
          exactly one object with that objectSid must exist.

        That object and the test object are deleted again at the end.
        """

        dom_sid = self.samdb.get_domain_sid()
        lsid_str = str(dom_sid) + "-4294967294"
        bsid_str = "S-1-5-32-4294967294"
        fsid_str = "S-1-5-4294967294"
        basedn = self.samdb.get_default_basedn()
        cn = "dsdb_test_fpo"
        dn_str = "cn=%s,cn=Users,%s" % (cn, basedn)
        dn = ldb.Dn(self.samdb, dn_str)

        def find_by_sid(sid_str):
            # all objects below basedn carrying the given objectSid
            return self.samdb.search(scope=ldb.SCOPE_SUBTREE,
                                     base=basedn,
                                     expression="(objectSid=%s)" % sid_str,
                                     attrs=[])

        def sid_add_message(sid_str):
            # a modify message adding <SID=sid_str> to fpo_attr of dn
            request = ldb.Message()
            request.dn = dn
            request[fpo_attr] = ldb.MessageElement("<SID=%s>" % sid_str,
                                                   ldb.FLAG_MOD_ADD, fpo_attr)
            return request

        # none of the SIDs may be in use before the test starts
        for sid_str in (lsid_str, bsid_str, fsid_str):
            self.assertEqual(len(find_by_sid(sid_str)), 0)

        self.addCleanup(delete_force, self.samdb, dn_str)

        self.samdb.add({"dn": dn_str, "objectClass": obj_class})

        try:
            self.samdb.modify(sid_add_message(lsid_str))
        except ldb.LdbError as err:
            code, text = err.args
            self.assertEqual(code, ldb.ERR_UNWILLING_TO_PERFORM, str(err))
            werr = "%08X" % werror.WERR_DS_INVALID_GROUP_TYPE
            self.assertTrue(werr in text, text)
        else:
            self.fail("No exception should get LDB_ERR_UNWILLING_TO_PERFORM")

        try:
            self.samdb.modify(sid_add_message(bsid_str))
        except ldb.LdbError as err:
            code, text = err.args
            self.assertEqual(code, ldb.ERR_NO_SUCH_OBJECT, str(err))
            werr = "%08X" % werror.WERR_NO_SUCH_MEMBER
            self.assertTrue(werr in text, text)
        else:
            self.fail("No exception should get LDB_ERR_NO_SUCH_OBJECT")

        try:
            self.samdb.modify(sid_add_message(fsid_str))
        except ldb.LdbError:
            self.fail("Should have not raised an exception")

        # the successful add must have left exactly one object with that SID
        matches = find_by_sid(fsid_str)
        self.assertEqual(len(matches), 1)
        self.samdb.delete(matches[0].dn)
        self.samdb.delete(dn)
        self.assertEqual(len(find_by_sid(fsid_str)), 0)

    def test_foreignSecurityPrincipal_member(self):
        return self._test_foreignSecurityPrincipal("group", "member")

    def test_foreignSecurityPrincipal_MembersForAzRole(self):
        return self._test_foreignSecurityPrincipal("msDS-AzRole",
                                                   "msDS-MembersForAzRole")

    def test_foreignSecurityPrincipal_NeverRevealGroup(self):
        return self._test_foreignSecurityPrincipal("computer",
                                                   "msDS-NeverRevealGroup")

    def test_foreignSecurityPrincipal_RevealOnDemandGroup(self):
        return self._test_foreignSecurityPrincipal("computer",
                                                   "msDS-RevealOnDemandGroup")

    def _test_fail_foreignSecurityPrincipal(self,
                                            obj_class,
                                            fpo_attr,
                                            msg_exp,
                                            lerr_exp,
                                            werr_exp,
                                            allow_reference=True):
        """Check that <SID=...> values are refused for ``fpo_attr``.

        Two objects of class ``obj_class`` are created.  Adding a SID
        reference to ``fpo_attr`` of the first one must fail for each of
        three SIDs (below the domain SID, below S-1-5-32, and
        S-1-5-4294967294).  Finally the DN of the second object is added as
        a plain reference.

        :param obj_class: objectClass of the two test objects
        :param fpo_attr: attribute that the references are added to
        :param msg_exp: description of the expected error, used only in
            failure messages
        :param lerr_exp: expected LDB error code
        :param werr_exp: expected WERROR code; its "%08X" form must occur in
            the error string
        :param allow_reference: if True the plain DN reference must be
            accepted, otherwise it must fail with the same expected error
        """

        dom_sid = self.samdb.get_domain_sid()
        lsid_str = str(dom_sid) + "-4294967294"
        bsid_str = "S-1-5-32-4294967294"
        fsid_str = "S-1-5-4294967294"
        basedn = self.samdb.get_default_basedn()
        cn1 = "dsdb_test_fpo1"
        dn1_str = "cn=%s,cn=Users,%s" % (cn1, basedn)
        dn1 = ldb.Dn(self.samdb, dn1_str)
        cn2 = "dsdb_test_fpo2"
        dn2_str = "cn=%s,cn=Users,%s" % (cn2, basedn)
        dn2 = ldb.Dn(self.samdb, dn2_str)

        # none of the SIDs may be in use before the test starts
        for sid_str in (lsid_str, bsid_str, fsid_str):
            res = self.samdb.search(scope=ldb.SCOPE_SUBTREE,
                                    base=basedn,
                                    expression="(objectSid=%s)" % sid_str,
                                    attrs=[])
            self.assertEqual(len(res), 0)

        self.addCleanup(delete_force, self.samdb, dn1_str)
        self.addCleanup(delete_force, self.samdb, dn2_str)

        self.samdb.add({"dn": dn1_str, "objectClass": obj_class})

        self.samdb.add({"dn": dn2_str, "objectClass": obj_class})

        # Every SID form has to be refused with the same expected error.
        for sid_str in (lsid_str, bsid_str, fsid_str):
            msg = ldb.Message()
            msg.dn = dn1
            msg[fpo_attr] = ldb.MessageElement("<SID=%s>" % sid_str,
                                               ldb.FLAG_MOD_ADD, fpo_attr)
            try:
                self.samdb.modify(msg)
            except ldb.LdbError as e:
                (code, errstr) = e.args
                self.assertEqual(code, lerr_exp, str(e))
                werr = "%08X" % werr_exp
                self.assertTrue(werr in errstr, errstr)
            else:
                # report the expected error, not the ldb.Message object
                self.fail("No exception should get %s" % msg_exp)

        # A plain DN reference to the second object.
        msg = ldb.Message()
        msg.dn = dn1
        msg[fpo_attr] = ldb.MessageElement("%s" % dn2, ldb.FLAG_MOD_ADD,
                                           fpo_attr)
        try:
            self.samdb.modify(msg)
            if not allow_reference:
                self.fail("No exception should get %s" % msg_exp)
        except ldb.LdbError as e:
            if allow_reference:
                self.fail("Should have not raised an exception: %s" % e)
            (code, errstr) = e.args
            self.assertEqual(code, lerr_exp, str(e))
            werr = "%08X" % werr_exp
            self.assertTrue(werr in errstr, errstr)

        self.samdb.delete(dn2)
        self.samdb.delete(dn1)

    def test_foreignSecurityPrincipal_NonMembers(self):
        """msDS-NonMembers on a group: SIDs and plain references must fail."""
        expected = "LDB_ERR_UNWILLING_TO_PERFORM/WERR_NOT_SUPPORTED"
        return self._test_fail_foreignSecurityPrincipal(
            obj_class="group",
            fpo_attr="msDS-NonMembers",
            msg_exp=expected,
            lerr_exp=ldb.ERR_UNWILLING_TO_PERFORM,
            werr_exp=werror.WERR_NOT_SUPPORTED,
            allow_reference=False)

    def test_foreignSecurityPrincipal_HostServiceAccount(self):
        """msDS-HostServiceAccount on a computer: SID values must fail."""
        expected = ("LDB_ERR_CONSTRAINT_VIOLATION/"
                    "WERR_DS_NAME_REFERENCE_INVALID")
        return self._test_fail_foreignSecurityPrincipal(
            obj_class="computer",
            fpo_attr="msDS-HostServiceAccount",
            msg_exp=expected,
            lerr_exp=ldb.ERR_CONSTRAINT_VIOLATION,
            werr_exp=werror.WERR_DS_NAME_REFERENCE_INVALID)

    def test_foreignSecurityPrincipal_manager(self):
        """manager on a user: SID values must fail."""
        expected = ("LDB_ERR_CONSTRAINT_VIOLATION/"
                    "WERR_DS_NAME_REFERENCE_INVALID")
        return self._test_fail_foreignSecurityPrincipal(
            obj_class="user",
            fpo_attr="manager",
            msg_exp=expected,
            lerr_exp=ldb.ERR_CONSTRAINT_VIOLATION,
            werr_exp=werror.WERR_DS_NAME_REFERENCE_INVALID)

    #
    # Duplicate objectSID's should not be permitted for sids in the local
    # domain. The test sequence is add an object, delete it, then attempt to
    # re-add it, this should fail with a constraint violation
    #
    def test_duplicate_objectSIDs_not_allowed_on_local_objects(self):
        """Re-adding a deleted local object with the same SID must fail.

        A user with a freshly allocated RID is added and deleted; adding it
        again with the same objectSID is expected to raise ldb.LdbError with
        ERR_CONSTRAINT_VIOLATION.
        """

        dom_sid = self.samdb.get_domain_sid()
        rid = self.allocate_rid()
        sid_str = str(dom_sid) + "-" + rid
        sid = ndr_pack(security.dom_sid(sid_str))
        basedn = self.samdb.get_default_basedn()
        cn = "dsdb_test_01"
        dn = "cn=%s,cn=Users,%s" % (cn, basedn)

        # make sure the object does not outlive the test, e.g. when the
        # second add unexpectedly succeeds
        self.addCleanup(delete_force, self.samdb, dn)

        self.samdb.add({"dn": dn, "objectClass": "user", "objectSID": sid})
        self.samdb.delete(dn)

        try:
            self.samdb.add({"dn": dn, "objectClass": "user", "objectSID": sid})
        except ldb.LdbError as e:
            (code, msg) = e.args
            if code != ldb.ERR_CONSTRAINT_VIOLATION:
                self.fail("Got %d - %s should have got "
                          "LDB_ERR_CONSTRAINT_VIOLATION" % (code, msg))
        else:
            self.fail("No exception should get LDB_ERR_CONSTRAINT_VIOLATION")

    def test_linked_vs_non_linked_reference(self):
        """Compare references to a deleted object made through a linked
        attribute ('manager') and a non-linked one ('assistant').

        Two users are created; one is kept and one is deleted after its
        objectGUID and objectSid have been recorded.  The test then
        expects, on the kept user:

        * adding 'manager' by <SID=...> or <GUID=...> of the deleted user
          fails with ERR_CONSTRAINT_VIOLATION and
          WERR_DS_NAME_REFERENCE_INVALID in the error string;
        * adding and then deleting 'assistant' by <SID=...> or
          <GUID=...> of the deleted user succeeds;
        * adding 'assistant' with a DN, SID or GUID that was never
          created in this test fails in the same way as 'manager'.
        """
        basedn = self.samdb.get_default_basedn()
        kept_dn_str = "cn=reference_kept,cn=Users,%s" % (basedn)
        removed_dn_str = "cn=reference_removed,cn=Users,%s" % (basedn)
        dom_sid = self.samdb.get_domain_sid()
        none_sid_str = str(dom_sid) + "-4294967294"
        none_guid_str = "afafafaf-fafa-afaf-fafa-afafafafafaf"

        self.addCleanup(delete_force, self.samdb, kept_dn_str)
        self.addCleanup(delete_force, self.samdb, removed_dn_str)

        self.samdb.add({"dn": kept_dn_str, "objectClass": "user"})
        res = self.samdb.search(scope=ldb.SCOPE_SUBTREE,
                                base=kept_dn_str,
                                attrs=["objectGUID", "objectSID"])
        self.assertEqual(len(res), 1)
        kept_dn = res[0].dn

        self.samdb.add({"dn": removed_dn_str, "objectClass": "user"})
        res = self.samdb.search(scope=ldb.SCOPE_SUBTREE,
                                base=removed_dn_str,
                                attrs=["objectGUID", "objectSID"])
        self.assertEqual(len(res), 1)
        removed_guid = ndr_unpack(misc.GUID, res[0]["objectGUID"][0])
        removed_sid = ndr_unpack(security.dom_sid, res[0]["objectSid"][0])
        self.samdb.delete(removed_dn_str)

        def modify_kept(attr, value, flag):
            # Apply a single-element modification to the kept user.
            m = ldb.Message()
            m.dn = kept_dn
            m[attr] = ldb.MessageElement(value, flag, attr)
            self.samdb.modify(m)

        def assert_reference_rejected(attr, value):
            # Adding the value must fail with a constraint violation that
            # names WERR_DS_NAME_REFERENCE_INVALID in its error string.
            try:
                modify_kept(attr, value, ldb.FLAG_MOD_ADD)
            except ldb.LdbError as e:
                (code, errstr) = e.args
                self.assertEqual(code, ldb.ERR_CONSTRAINT_VIOLATION, str(e))
                werr = "%08X" % werror.WERR_DS_NAME_REFERENCE_INVALID
                self.assertTrue(werr in errstr, errstr)
            else:
                self.fail("No exception should get "
                          "LDB_ERR_CONSTRAINT_VIOLATION")

        #
        # First try the linked attribute 'manager'
        # by GUID and SID
        #
        assert_reference_rejected("manager", "<SID=%s>" % removed_sid)
        assert_reference_rejected("manager", "<GUID=%s>" % removed_guid)

        #
        # Try the non-linked attribute 'assistant'
        # by GUID and SID, which should work.
        #
        modify_kept("assistant", "<SID=%s>" % removed_sid,
                    ldb.FLAG_MOD_ADD)
        modify_kept("assistant", "<SID=%s>" % removed_sid,
                    ldb.FLAG_MOD_DELETE)

        modify_kept("assistant", "<GUID=%s>" % removed_guid,
                    ldb.FLAG_MOD_ADD)
        modify_kept("assistant", "<GUID=%s>" % removed_guid,
                    ldb.FLAG_MOD_DELETE)

        #
        # Finally try the non-linked attribute 'assistant'
        # but with non existing GUID, SID, DN
        #
        assert_reference_rejected("assistant", "CN=NoneNone,%s" % (basedn))
        assert_reference_rejected("assistant", "<SID=%s>" % none_sid_str)
        assert_reference_rejected("assistant", "<GUID=%s>" % none_guid_str)

        self.samdb.delete(kept_dn)

    def test_normalize_dn_in_domain_full(self):
        """A string DN that already ends in the domain DN must come back
        from normalize_dn_in_domain() equal to the same DN."""
        base = self.samdb.domain_dn()

        expected = ldb.Dn(self.samdb, "CN=Users")
        expected.add_base(base)

        # The input is the complete DN rendered as a string, so the
        # result is expected to compare equal to it: no change.
        normalized = self.samdb.normalize_dn_in_domain(str(expected))
        self.assertEqual(expected, normalized)

    def test_normalize_dn_in_domain_part(self):
        """A string DN without the domain part must come back from
        normalize_dn_in_domain() with the domain DN appended."""
        base = self.samdb.domain_dn()
        relative = "CN=Users"

        expected = ldb.Dn(self.samdb, relative)
        expected.add_base(base)

        # Only the relative string is passed in; the result is expected
        # to equal the relative DN with the domain DN as its base.
        normalized = self.samdb.normalize_dn_in_domain(relative)
        self.assertEqual(expected, normalized)

    def test_normalize_dn_in_domain_full_dn(self):
        """An ldb.Dn that already ends in the domain DN must come back
        from normalize_dn_in_domain() equal to itself."""
        base = self.samdb.domain_dn()

        complete = ldb.Dn(self.samdb, "CN=Users")
        complete.add_base(base)

        # The Dn object itself is passed in (not a string); no change is
        # expected.
        normalized = self.samdb.normalize_dn_in_domain(complete)
        self.assertEqual(complete, normalized)

    def test_normalize_dn_in_domain_part_dn(self):
        """An ldb.Dn without the domain part must come back from
        normalize_dn_in_domain() with the domain DN appended."""
        base = self.samdb.domain_dn()
        relative = ldb.Dn(self.samdb, "CN=Users")

        # Build the expected DN by plain string concatenation so the
        # relative Dn object passed to the call is not modified here.
        expected = ldb.Dn(self.samdb, str(relative) + "," + str(base))

        normalized = self.samdb.normalize_dn_in_domain(relative)
        self.assertEqual(expected, normalized)
Exemple #18
0
class RodcTests(samba.tests.TestCase):
    """Tests that write attempts against the server at HOST are refused.

    Most tests expect the write to fail with ldb.ERR_REFERRAL and then
    inspect the ldap:// URL found in the error message.

    NOTE(review): the class name suggests HOST is a read-only domain
    controller; that is not visible in this class -- confirm against the
    code that sets HOST.
    """

    def setUp(self):
        """Connect to HOST and record the domain DN, the dsServiceName
        from the rootDSE, and a random tag used to build unique names."""
        super(RodcTests, self).setUp()
        self.samdb = SamDB(HOST, credentials=CREDS,
                           session_info=system_session(LP), lp=LP)

        self.base_dn = self.samdb.domain_dn()

        root = self.samdb.search(base='', scope=ldb.SCOPE_BASE,
                                 attrs=['dsServiceName'])
        self.service = root[0]['dsServiceName'][0]
        self.tag = uuid.uuid4().hex

    def _referral_address(self, emsg):
        """Return the ldap:// URL found in a referral error message.

        The URL is everything from 'ldap://' up to, but not including,
        the next '>'.  The test fails if no such URL is present.
        """
        m = re.search(r'(ldap://[^>]+)>', emsg)
        if m is None:
            self.fail("referral seems not to refer to anything")
        return m.group(1)

    def _assert_specific_dc(self, address):
        """Fail if the referral address starts with the domain DNS name.

        NOTE(review): `address` as returned by _referral_address() still
        begins with the 'ldap://' scheme, so this prefix comparison looks
        like it can only trigger if the domain DNS name itself starts
        with 'ldap://'.  Confirm whether the host part of the URL was
        meant to be compared instead.  Behaviour is left unchanged.
        """
        if address.lower().startswith(self.samdb.domain_dns_name()):
            self.fail("referral address did not give a specific DC")

    def test_add_replicated_objects(self):
        """Adding objects must produce a referral, and the referred
        server must accept (and then delete) the same object."""
        for o in (
                {
                    'dn': "ou=%s1,%s" % (self.tag, self.base_dn),
                    "objectclass": "organizationalUnit"
                },
                {
                    'dn': "cn=%s2,%s" % (self.tag, self.base_dn),
                    "objectclass": "user"
                },
                {
                    'dn': "cn=%s3,%s" % (self.tag, self.base_dn),
                    "objectclass": "group"
                },
                {
                    'dn': "cn=%s4,%s" % (self.tag, self.service),
                    "objectclass": "NTDSConnection",
                    'enabledConnection': 'TRUE',
                    'fromServer': self.base_dn,
                    'options': '0'
                },
        ):
            try:
                self.samdb.add(o)
                self.fail("Failed to fail to add %s" % o['dn'])
            except ldb.LdbError as e:
                (ecode, emsg) = e.args
                if ecode != ldb.ERR_REFERRAL:
                    print(emsg)
                    self.fail("Adding %s: ldb error: %s %s, wanted referral" %
                              (o['dn'], ecode, emsg))

                address = self._referral_address(emsg)

                # The server we were referred to must take the write.
                try:
                    tmpdb = SamDB(address, credentials=CREDS,
                                  session_info=system_session(LP), lp=LP)
                    tmpdb.add(o)
                    tmpdb.delete(o['dn'])
                except ldb.LdbError:
                    self.fail("couldn't modify referred location %s" %
                              address)

                self._assert_specific_dc(address)

    def test_modify_replicated_attributes(self):
        """Replacing 'carLicense' or 'middleName' on the Guest user must
        produce a referral, and the referred server must accept the same
        modification."""
        dn = 'CN=Guest,CN=Users,' + self.base_dn
        value = 'hallooo'
        for attr in ['carLicense', 'middleName']:
            msg = ldb.Message()
            msg.dn = ldb.Dn(self.samdb, dn)
            msg[attr] = ldb.MessageElement(value,
                                           ldb.FLAG_MOD_REPLACE,
                                           attr)
            try:
                self.samdb.modify(msg)
                self.fail("Failed to fail to modify %s %s" % (dn, attr))
            except ldb.LdbError as e1:
                (ecode, emsg) = e1.args
                if ecode != ldb.ERR_REFERRAL:
                    self.fail("Failed to REFER when trying to modify %s %s" %
                              (dn, attr))

                address = self._referral_address(emsg)

                # The server we were referred to must take the write.
                try:
                    tmpdb = SamDB(address, credentials=CREDS,
                                  session_info=system_session(LP), lp=LP)
                    tmpdb.modify(msg)
                except ldb.LdbError:
                    self.fail("couldn't modify referred location %s" %
                              address)

                self._assert_specific_dc(address)

    def test_modify_nonreplicated_attributes(self):
        """Replacing 'badPwdCount', 'lastLogon' or 'lastLogoff' on the
        Guest user must produce a referral."""
        dn = 'CN=Guest,CN=Users,' + self.base_dn
        value = '123456789'
        for attr in ['badPwdCount', 'lastLogon', 'lastLogoff']:
            m = ldb.Message()
            m.dn = ldb.Dn(self.samdb, dn)
            m[attr] = ldb.MessageElement(value,
                                         ldb.FLAG_MOD_REPLACE,
                                         attr)
            # Windows refers these ones even though they are non-replicated
            try:
                self.samdb.modify(m)
                self.fail("Failed to fail to modify %s %s" % (dn, attr))
            except ldb.LdbError as e2:
                (ecode, emsg) = e2.args
                if ecode != ldb.ERR_REFERRAL:
                    self.fail("Failed to REFER when trying to modify %s %s" %
                              (dn, attr))

                address = self._referral_address(emsg)
                self._assert_specific_dc(address)

    def test_modify_nonreplicated_reps_attributes(self):
        """Replacing 'repsFrom' on the domain DN must produce a referral.

        The replacement value is the first existing repsFrom blob with
        its result_last_attempt field set to -1.
        """
        dn = self.base_dn

        m = ldb.Message()
        m.dn = ldb.Dn(self.samdb, dn)
        attr = 'repsFrom'

        res = self.samdb.search(dn, scope=ldb.SCOPE_BASE,
                                attrs=['repsFrom'])
        rep = ndr_unpack(drsblobs.repsFromToBlob, res[0]['repsFrom'][0],
                         allow_remaining=True)
        rep.ctr.result_last_attempt = -1
        value = ndr_pack(rep)

        m[attr] = ldb.MessageElement(value,
                                     ldb.FLAG_MOD_REPLACE,
                                     attr)
        try:
            self.samdb.modify(m)
            self.fail("Failed to fail to modify %s %s" % (dn, attr))
        except ldb.LdbError as e3:
            (ecode, emsg) = e3.args
            if ecode != ldb.ERR_REFERRAL:
                self.fail("Failed to REFER when trying to modify %s %s" %
                          (dn, attr))

            address = self._referral_address(emsg)
            self._assert_specific_dc(address)

    def test_delete_special_objects(self):
        """Deleting the Guest user must produce a referral."""
        dn = 'CN=Guest,CN=Users,' + self.base_dn
        try:
            self.samdb.delete(dn)
            self.fail("Failed to fail to delete %s" % (dn))
        except ldb.LdbError as e4:
            (ecode, emsg) = e4.args
            if ecode != ldb.ERR_REFERRAL:
                print(ecode, emsg)
                self.fail("Failed to REFER when trying to delete %s" % dn)

            address = self._referral_address(emsg)
            self._assert_specific_dc(address)

    def test_no_delete_nonexistent_objects(self):
        """Deleting a DN that does not exist must fail with
        ERR_NO_SUCH_OBJECT rather than any other error."""
        dn = 'CN=does-not-exist-%s,CN=Users,%s' % (self.tag, self.base_dn)
        try:
            self.samdb.delete(dn)
            self.fail("Failed to fail to delete %s" % (dn))
        except ldb.LdbError as e5:
            (ecode, emsg) = e5.args
            if ecode != ldb.ERR_NO_SUCH_OBJECT:
                print(ecode, emsg)
                self.fail("Failed to NO_SUCH_OBJECT when trying to delete "
                          "%s (which does not exist)" % dn)
class SubtreeRenameTests(samba.tests.TestCase):
    """Tests of linked attributes across renames of an OU subtree.

    Each test creates users, groups and computers under OU=subtree1
    and/or OU=subtree2, links them with 'member' and
    'msDS-RevealedUsers', renames OU=subtree1 to OU=subtree3 and then
    checks the forward and back links.
    """

    def delete_ous(self):
        """Tree-delete the three test OUs, ignoring any LdbError."""
        for ou in (self.ou1, self.ou2, self.ou3):
            try:
                self.samdb.delete(ou, ['tree_delete:1'])
            except ldb.LdbError:
                # Best effort: an OU may simply not exist at this point.
                pass

    def setUp(self):
        """Connect to the server and create OU=subtree1 and OU=subtree2.

        OU=subtree3 is only named here; the tests create it by renaming
        OU=subtree1.
        """
        super(SubtreeRenameTests, self).setUp()
        self.samdb = SamDB(host, credentials=creds,
                           session_info=system_session(lp), lp=lp)

        self.base_dn = self.samdb.domain_dn()
        self.ou1 = "OU=subtree1,%s" % self.base_dn
        self.ou2 = "OU=subtree2,%s" % self.base_dn
        self.ou3 = "OU=subtree3,%s" % self.base_dn
        if opts.delete_in_setup:
            self.delete_ous()
        self.samdb.add({'objectclass': 'organizationalUnit',
                        'dn': self.ou1})
        self.samdb.add({'objectclass': 'organizationalUnit',
                        'dn': self.ou2})

        debug(colour.c_REV_RED(self.id()))

    def tearDown(self):
        """Remove the test OUs unless opts.no_cleanup is set."""
        super(SubtreeRenameTests, self).tearDown()
        if not opts.no_cleanup:
            self.delete_ous()

    def add_object(self, cn, objectclass, ou=None, more_attrs=None):
        """Add one object named CN=<cn> under `ou` and return its DN.

        `more_attrs`, if given, is a dict of extra attributes merged over
        the default 'cn', 'objectclass' and 'dn' entries.
        """
        dn = "CN=%s,%s" % (cn, ou)
        attrs = {'cn': cn,
                 'objectclass': objectclass,
                 'dn': dn}
        # None instead of a shared mutable {} default.
        if more_attrs:
            attrs.update(more_attrs)
        self.samdb.add(attrs)

        return dn

    def add_objects(self, n, objectclass, prefix=None, ou=None,
                    more_attrs=None):
        """Add `n` objects named <prefix>1 .. <prefix>n under `ou`.

        `prefix` defaults to the objectclass name.  Returns the list of
        DNs in creation order.
        """
        if prefix is None:
            prefix = objectclass
        return [self.add_object("%s%d" % (prefix, i + 1),
                                objectclass,
                                more_attrs=more_attrs,
                                ou=ou)
                for i in range(n)]

    def add_linked_attribute(self, src, dest, attr='member',
                             controls=None):
        """Add `dest` as a value of `attr` on the object `src`."""
        m = ldb.Message()
        m.dn = ldb.Dn(self.samdb, src)
        m[attr] = ldb.MessageElement(dest, ldb.FLAG_MOD_ADD, attr)
        self.samdb.modify(m, controls=controls)

    def remove_linked_attribute(self, src, dest, attr='member',
                                controls=None):
        """Delete the value `dest` from `attr` on the object `src`."""
        m = ldb.Message()
        m.dn = ldb.Dn(self.samdb, src)
        m[attr] = ldb.MessageElement(dest, ldb.FLAG_MOD_DELETE, attr)
        self.samdb.modify(m, controls=controls)

    @staticmethod
    def _binary_link_value(dest, binary):
        """Return the 'B:<length>:<HEX>:<dest>' string for a binary link.

        <HEX> is the upper-case hex encoding of the UTF-8 bytes of
        str(binary) and <length> is the number of hex characters.
        """
        b = hexlify(str(binary).encode('utf-8')).decode('utf-8').upper()
        return 'B:%d:%s:%s' % (len(b), b, dest)

    def add_binary_link(self, src, dest, binary,
                        attr='msDS-RevealedUsers',
                        controls=None):
        """Add a binary-plus-DN link from `src` to `dest`.

        Returns the full link value that was written, so callers can
        compare it with what a later search returns.
        """
        dest = self._binary_link_value(dest, binary)
        self.add_linked_attribute(src, dest, attr, controls)
        return dest

    def remove_binary_link(self, src, dest, binary,
                           attr='msDS-RevealedUsers',
                           controls=None):
        """Remove a link previously created with add_binary_link().

        The value is built by the same helper as in add_binary_link() so
        that both methods name exactly the same link for the same
        (dest, binary) pair.
        """
        dest = self._binary_link_value(dest, binary)
        self.remove_linked_attribute(src, dest, attr, controls)

    def replace_linked_attribute(self, src, dest, attr='member',
                                 controls=None):
        """Replace the values of `attr` on the object `src` with `dest`."""
        m = ldb.Message()
        m.dn = ldb.Dn(self.samdb, src)
        m[attr] = ldb.MessageElement(dest, ldb.FLAG_MOD_REPLACE, attr)
        self.samdb.modify(m, controls=controls)

    def attr_search(self, obj, attr, scope=ldb.SCOPE_BASE, **controls):
        """Search `obj` for the single attribute `attr`.

        Extra keyword arguments are turned into search controls of the
        form '<name>:<int(value)>'.
        """
        control_strings = ['%s:%d' % (k, int(v))
                           for k, v in controls.items()]

        res = self.samdb.search(obj,
                                scope=scope,
                                attrs=[attr],
                                controls=control_strings)
        return res

    def assert_links(self, obj, expected, attr, msg='', **kwargs):
        """Assert that `attr` on `obj` holds exactly the `expected` values.

        An empty `expected` means the attribute must be absent.  Values
        are compared as sorted lists of strings, so order is ignored but
        duplicates count.  `msg` is only written to the debug output
        when the lists differ.
        """
        res = self.attr_search(obj, attr, **kwargs)

        if len(expected) == 0:
            if attr in res[0]:
                self.fail("found attr '%s' in %s" % (attr, res[0]))
            return

        try:
            results = [str(x) for x in res[0][attr]]
        except KeyError:
            self.fail("missing attr '%s' on %s" % (attr, obj))

        expected = sorted(expected)
        results = sorted(results)

        if expected != results:
            debug(msg)
            debug("expected %s" % expected)
            debug("received %s" % results)
            debug("missing    %s" % (sorted(set(expected) - set(results))))
            debug("unexpected %s" % (sorted(set(results) - set(expected))))

        self.assertEqual(results, expected)

    def assert_back_links(self, obj, expected, attr='memberOf', **kwargs):
        """assert_links() with a back-link default attribute and message."""
        self.assert_links(obj, expected, attr=attr,
                          msg='%s back links do not match for %s' %
                          (attr, obj),
                          **kwargs)

    def assert_forward_links(self, obj, expected, attr='member', **kwargs):
        """assert_links() with a forward-link default attribute and
        message."""
        self.assert_links(obj, expected, attr=attr,
                          msg='%s forward links do not match for %s' %
                          (attr, obj),
                          **kwargs)

    def get_object_guid(self, dn):
        """Return the objectGUID of `dn` as a string."""
        res = self.samdb.search(dn,
                                scope=ldb.SCOPE_BASE,
                                attrs=['objectGUID'])
        return str(misc.GUID(res[0]['objectGUID'][0]))

    @staticmethod
    def _ldb_error_names():
        """Map each integer ldb.ERR_* value to its constant name."""
        return {v: k for k, v in vars(ldb).items()
                if k.startswith('ERR_') and isinstance(v, int)}

    def assertRaisesLdbError(self, errcode, message, f, *args, **kwargs):
        """Assert a function raises a particular LdbError.

        `f(*args, **kwargs)` must raise ldb.LdbError with error number
        `errcode`; any other LdbError number, or no exception at all,
        fails the test with `message` as a prefix.  Exceptions of other
        types propagate unchanged.
        """
        try:
            f(*args, **kwargs)
        except ldb.LdbError as e:
            (num, msg) = e.args
            if num != errcode:
                lut = self._ldb_error_names()
                self.fail("%s, expected "
                          "LdbError %s, (%d) "
                          "got %s (%d) "
                          "%s" % (message,
                                  lut.get(errcode), errcode,
                                  lut.get(num), num,
                                  msg))
        else:
            lut = self._ldb_error_names()
            self.fail("%s, expected "
                      "LdbError %s, (%d) "
                      "but we got success" % (message,
                                              lut.get(errcode),
                                              errcode))

    def test_la_move_ou_tree(self):
        """Users, groups and computers all live in the renamed OU."""
        tag = 'move_tree'

        u1, u2 = self.add_objects(2, 'user', '%s_u_' % tag, ou=self.ou1)
        g1, g2 = self.add_objects(2, 'group', '%s_g_' % tag, ou=self.ou1)
        c1, c2, c3 = self.add_objects(3, 'computer',
                                      '%s_c_' % tag,
                                      ou=self.ou1)

        self.add_linked_attribute(g1, u1)
        self.add_linked_attribute(g1, g2)
        self.add_linked_attribute(g2, u1)
        self.add_linked_attribute(g2, u2)
        # The expected link values are rewritten to the post-rename OU.
        c1u1 = self.add_binary_link(c1, u1, 'a').replace(self.ou1, self.ou3)
        c2u1 = self.add_binary_link(c2, u1, 'b').replace(self.ou1, self.ou3)
        c3u1 = self.add_binary_link(c3, u1, 124.543).replace(self.ou1, self.ou3)
        c1g1 = self.add_binary_link(c1, g1, 'd').replace(self.ou1, self.ou3)
        # Not among the expected links below, once g2 has been deleted.
        self.add_binary_link(c2, g2, 'd')
        c2c1 = self.add_binary_link(c2, c1, 'd').replace(self.ou1, self.ou3)
        c1u2 = self.add_binary_link(c1, u2, 'd').replace(self.ou1, self.ou3)
        c1u1_2 = self.add_binary_link(c1, u1, 'b').replace(self.ou1, self.ou3)

        # 20 is the LDAP attributeOrValueExists result code: adding the
        # same (c1, u2, 'd') link a second time must be refused.
        self.assertRaisesLdbError(20,
                                  "Attribute msDS-RevealedUsers already exists",
                                  self.add_binary_link, c1, u2, 'd')

        self.samdb.rename(self.ou1, self.ou3)
        debug(colour.c_CYAN("rename FINISHED"))
        u1, u2, g1, g2, c1, c2, c3 = [x.replace(self.ou1, self.ou3)
                                      for x in (u1, u2, g1, g2, c1, c2, c3)]

        self.samdb.delete(g2, ['tree_delete:1'])

        self.assert_forward_links(g1, [u1])
        self.assert_back_links(u1, [g1])
        self.assert_back_links(u2, set())
        self.assert_forward_links(c1, [c1u1, c1u1_2, c1u2, c1g1],
                                  attr='msDS-RevealedUsers')
        self.assert_forward_links(c2, [c2u1, c2c1], attr='msDS-RevealedUsers')
        self.assert_forward_links(c3, [c3u1], attr='msDS-RevealedUsers')
        self.assert_back_links(u1, [c1, c1, c2, c3], attr='msDS-RevealedDSAs')
        self.assert_back_links(u2, [c1], attr='msDS-RevealedDSAs')
        self.assert_back_links(g1, [c1], attr='msDS-RevealedDSAs')
        self.assert_back_links(c1, [c2], attr='msDS-RevealedDSAs')

    def test_la_move_ou_groups(self):
        """Groups and computers live in the renamed OU; users do not."""
        tag = 'move_groups'

        u1, u2 = self.add_objects(2, 'user', '%s_u_' % tag, ou=self.ou2)
        g1, g2 = self.add_objects(2, 'group', '%s_g_' % tag, ou=self.ou1)
        c1, c2, c3 = self.add_objects(3, 'computer',
                                      '%s_c_' % tag,
                                      ou=self.ou1)

        self.add_linked_attribute(g1, u1)
        self.add_linked_attribute(g1, g2)
        self.add_linked_attribute(g2, u1)
        self.add_linked_attribute(g2, u2)
        # The expected link values are rewritten to the post-rename OU.
        c1u1 = self.add_binary_link(c1, u1, 'a').replace(self.ou1, self.ou3)
        c2u1 = self.add_binary_link(c2, u1, 'b').replace(self.ou1, self.ou3)
        c3u1 = self.add_binary_link(c3, u1, 124.543).replace(self.ou1, self.ou3)
        c1g1 = self.add_binary_link(c1, g1, 'd').replace(self.ou1, self.ou3)
        # Not among the expected links below, once g2 has been deleted.
        self.add_binary_link(c2, g2, 'd')
        c2c1 = self.add_binary_link(c2, c1, 'd').replace(self.ou1, self.ou3)
        c1u2 = self.add_binary_link(c1, u2, 'd').replace(self.ou1, self.ou3)
        c1u1_2 = self.add_binary_link(c1, u1, 'b').replace(self.ou1, self.ou3)

        self.samdb.rename(self.ou1, self.ou3)
        debug(colour.c_CYAN("rename FINISHED"))
        u1, u2, g1, g2, c1, c2, c3 = [x.replace(self.ou1, self.ou3)
                                      for x in (u1, u2, g1, g2, c1, c2, c3)]

        self.samdb.delete(g2, ['tree_delete:1'])

        self.assert_forward_links(g1, [u1])
        self.assert_back_links(u1, [g1])
        self.assert_back_links(u2, set())
        self.assert_forward_links(c1, [c1u1, c1u1_2, c1u2, c1g1],
                                  attr='msDS-RevealedUsers')
        self.assert_forward_links(c2, [c2u1, c2c1], attr='msDS-RevealedUsers')
        self.assert_forward_links(c3, [c3u1], attr='msDS-RevealedUsers')
        self.assert_back_links(u1, [c1, c1, c2, c3], attr='msDS-RevealedDSAs')
        self.assert_back_links(u2, [c1], attr='msDS-RevealedDSAs')
        self.assert_back_links(g1, [c1], attr='msDS-RevealedDSAs')
        self.assert_back_links(c1, [c2], attr='msDS-RevealedDSAs')

    def test_la_move_ou_users(self):
        """Users and computers live in the renamed OU; groups do not."""
        tag = 'move_users'

        u1, u2 = self.add_objects(2, 'user', '%s_u_' % tag, ou=self.ou1)
        g1, g2 = self.add_objects(2, 'group', '%s_g_' % tag, ou=self.ou2)
        c1, c2 = self.add_objects(2, 'computer', '%s_c_' % tag, ou=self.ou1)

        self.add_linked_attribute(g1, u1)
        self.add_linked_attribute(g1, g2)
        self.add_linked_attribute(g2, u1)
        self.add_linked_attribute(g2, u2)
        # The expected link values are rewritten to the post-rename OU.
        c1u1 = self.add_binary_link(c1, u1, 'a').replace(self.ou1, self.ou3)
        c2u1 = self.add_binary_link(c2, u1, 'b').replace(self.ou1, self.ou3)
        c1g1 = self.add_binary_link(c1, g1, 'd').replace(self.ou1, self.ou3)
        # Not among the expected links below, once g2 has been deleted.
        self.add_binary_link(c2, g2, 'd')
        c2c1 = self.add_binary_link(c2, c1, 'd').replace(self.ou1, self.ou3)
        c1u2 = self.add_binary_link(c1, u2, 'd').replace(self.ou1, self.ou3)
        c1u1_2 = self.add_binary_link(c1, u1, 'b').replace(self.ou1, self.ou3)

        self.samdb.rename(self.ou1, self.ou3)
        debug(colour.c_CYAN("rename FINISHED"))
        u1, u2, g1, g2, c1, c2 = [x.replace(self.ou1, self.ou3)
                                  for x in (u1, u2, g1, g2, c1, c2)]

        self.samdb.delete(g2, ['tree_delete:1'])

        self.assert_forward_links(g1, [u1])
        self.assert_back_links(u1, [g1])
        self.assert_back_links(u2, set())
        self.assert_forward_links(c1, [c1u1, c1u1_2, c1u2, c1g1],
                                  attr='msDS-RevealedUsers')
        self.assert_forward_links(c2, [c2u1, c2c1], attr='msDS-RevealedUsers')
        self.assert_back_links(u1, [c1, c1, c2], attr='msDS-RevealedDSAs')
        self.assert_back_links(u2, [c1], attr='msDS-RevealedDSAs')
        self.assert_back_links(g1, [c1], attr='msDS-RevealedDSAs')
        self.assert_back_links(c1, [c2], attr='msDS-RevealedDSAs')

    def test_la_move_ou_noncomputers(self):
        """Here we are especially testing the msDS-RevealedDSAs links"""
        tag = 'move_noncomputers'

        u1, u2 = self.add_objects(2, 'user', '%s_u_' % tag, ou=self.ou1)
        g1, g2 = self.add_objects(2, 'group', '%s_g_' % tag, ou=self.ou1)
        c1, c2, c3 = self.add_objects(3, 'computer', '%s_c_' % tag, ou=self.ou2)

        self.add_linked_attribute(g1, u1)
        self.add_linked_attribute(g1, g2)
        # The expected link values are rewritten to the post-rename OU.
        c1u1 = self.add_binary_link(c1, u1, 'a').replace(self.ou1, self.ou3)
        c2u1 = self.add_binary_link(c2, u1, 'b').replace(self.ou1, self.ou3)
        c2u1_2 = self.add_binary_link(c2, u1, 'c').replace(self.ou1, self.ou3)
        # Links from c3 are not among the expected values below, once c3
        # has been deleted.
        self.add_binary_link(c3, g1, 'b')
        c1g1 = self.add_binary_link(c1, g1, 'd').replace(self.ou1, self.ou3)
        c2g2 = self.add_binary_link(c2, g2, 'd').replace(self.ou1, self.ou3)
        c2c1 = self.add_binary_link(c2, c1, 'd').replace(self.ou1, self.ou3)
        c1u2 = self.add_binary_link(c1, u2, 'd').replace(self.ou1, self.ou3)
        c1u1_2 = self.add_binary_link(c1, u1, 'b').replace(self.ou1, self.ou3)
        c1u1_3 = self.add_binary_link(c1, u1, 'c').replace(self.ou1, self.ou3)
        c2u1_3 = self.add_binary_link(c2, u1, 'e').replace(self.ou1, self.ou3)
        self.add_binary_link(c3, u2, 'b')

        self.samdb.rename(self.ou1, self.ou3)
        debug(colour.c_CYAN("rename FINISHED"))
        u1, u2, g1, g2, c1, c2, c3 = [x.replace(self.ou1, self.ou3)
                                      for x in (u1, u2, g1, g2, c1, c2, c3)]

        self.samdb.delete(c3, ['tree_delete:1'])

        self.assert_forward_links(g1, [g2, u1])
        self.assert_back_links(u1, [g1])
        self.assert_back_links(u2, [])
        self.assert_forward_links(c1, [c1u1, c1u1_2, c1u1_3, c1u2, c1g1],
                                  attr='msDS-RevealedUsers')
        self.assert_forward_links(c2, [c2u1, c2u1_2, c2u1_3, c2c1, c2g2],
                                  attr='msDS-RevealedUsers')
        self.assert_back_links(u1, [c1, c1, c1, c2, c2, c2],
                               attr='msDS-RevealedDSAs')
        self.assert_back_links(u2, [c1], attr='msDS-RevealedDSAs')
        self.assert_back_links(g1, [c1], attr='msDS-RevealedDSAs')
        self.assert_back_links(c1, [c2], attr='msDS-RevealedDSAs')

    def test_la_move_ou_tree_big(self):
        """Time the rename of an OU holding many linked objects, then
        the rename of one group, one user and one computer out of it.

        Only timings are written to the debug output; nothing is
        asserted about the links afterwards.
        """
        tag = 'move_ou_big'
        USERS, GROUPS, COMPUTERS = 50, 10, 7

        users = self.add_objects(USERS, 'user', '%s_u_' % tag, ou=self.ou1)
        groups = self.add_objects(GROUPS, 'group', '%s_g_' % tag, ou=self.ou1)
        computers = self.add_objects(COMPUTERS, 'computer', '%s_c_' % tag,
                                     ou=self.ou1)

        start = time()
        for i, u in enumerate(users):
            # User i joins the first (i % GROUPS) groups and is linked
            # from the first (i % COMPUTERS) computers.
            for g in groups[:i % GROUPS]:
                self.add_linked_attribute(g, u)
            for c in computers[:i % COMPUTERS]:
                self.add_binary_link(c, u, 'a')

        debug("linking took %.3fs" % (time() - start))
        start = time()
        self.samdb.rename(self.ou1, self.ou3)
        debug("rename ou took %.3fs" % (time() - start))

        g1 = groups[0].replace(self.ou1, self.ou3)
        start = time()
        self.samdb.rename(g1, g1.replace(self.ou3, self.ou2))
        debug("rename group took %.3fs" % (time() - start))

        u1 = users[0].replace(self.ou1, self.ou3)
        start = time()
        self.samdb.rename(u1, u1.replace(self.ou3, self.ou2))
        debug("rename user took %.3fs" % (time() - start))

        c1 = computers[0].replace(self.ou1, self.ou3)
        start = time()
        self.samdb.rename(c1, c1.replace(self.ou3, self.ou2))
        debug("rename computer took %.3fs" % (time() - start))
class UserTests(samba.tests.TestCase):

    def add_if_possible(self, *args, **kwargs):
        """Call self.ldb.add() with the given arguments, ignoring LdbError.

        These tests sometimes leave objects in the database on purpose,
        so failing to add something a second time is not a problem.
        """
        try:
            self.ldb.add(*args, **kwargs)
        except LdbError:
            # presumably the object is already there; carry on regardless
            return

    def setUp(self):
        """Open a SamDB connection and make sure this process's OUs exist."""
        super(UserTests, self).setUp()
        # GlobalState is used as the class itself, not an instance, so
        # the counters stored on it are shared by every test method.
        self.state = GlobalState
        self.lp = lp
        self.ldb = SamDB(host,
                         credentials=creds,
                         session_info=system_session(lp),
                         lp=lp)
        self.base_dn = self.ldb.domain_dn()
        # everything lives under an OU named after our pid
        self.ou = "OU=pid%s,%s" % (os.getpid(), self.base_dn)
        self.ou_users = "OU=users,%s" % self.ou
        self.ou_groups = "OU=groups,%s" % self.ou
        self.ou_computers = "OU=computers,%s" % self.ou

        # Parent first, then children; an OU surviving from an earlier
        # test is fine (add_if_possible ignores the LdbError).
        containers = [self.ou, self.ou_users, self.ou_groups,
                      self.ou_computers]
        for container_dn in containers:
            self.add_if_possible({"dn": container_dn,
                                  "objectclass": "organizationalUnit"})

    def tearDown(self):
        """Defer to the base class; this class adds no cleanup of its own."""
        super(UserTests, self).tearDown()

    def test_00_00_do_nothing(self):
        # Deliberately empty: the time this test takes is the fixed
        # per-test overhead, which gives us an idea of the baseline for
        # all the other tests.
        pass

    def _prepare_n_groups(self, n):
        """Record n as the group count and ensure groups g0..g(n-1) exist.

        Groups are added with add_if_possible(), so ones that are already
        present are left alone.
        """
        self.state.n_groups = n
        for idx in range(n):
            group_dn = "cn=g%d,%s" % (idx, self.ou_groups)
            self.add_if_possible({"dn": group_dn, "objectclass": "group"})

    def _add_users(self, start, end):
        """Add user objects u<start> .. u<end - 1> under the users OU.

        Errors raised by self.ldb.add() are not caught here.
        """
        for uid in range(start, end):
            record = {
                "dn": "cn=u%d,%s" % (uid, self.ou_users),
                "objectclass": "user",
            }
            self.ldb.add(record)

    def _test_join(self):
        # Run "samba-tool domain join <realm> dc" against the server under
        # test, joining into a scratch target directory which is removed
        # afterwards.
        tmpdir = tempfile.mkdtemp()
        try:
            # Pass only the part after any URL scheme ("xxx://") as
            # --server.
            if '://' in host:
                server = host.split('://', 1)[1]
            else:
                server = host
            cmd = cmd_sambatool.subcommands['domain'].subcommands['join']
            cmd._run("samba-tool domain join",
                     creds.get_realm(),
                     "dc", "-U%s%%%s" % (creds.get_username(),
                                         creds.get_password()),
                     '--targetdir=%s' % tmpdir,
                     '--server=%s' % server)
        finally:
            # Clean up even when the join raises, so failed runs do not
            # leave scratch directories behind.
            shutil.rmtree(tmpdir)

    def _test_unindexed_search(self):
        # Time a fixed set of subtree searches under self.ou (going by the
        # method name, expressions that cannot be answered from an index).
        # Each expression is run `rounds` times and the total is written
        # to stderr.
        expressions = [
            ('(&(objectclass=user)(description='
             'Built-in account for adminstering the computer/domain))'),
            '(description=Built-in account for adminstering the computer/domain)',
            '(objectCategory=*)',
            '(samaccountname=Administrator*)'
        ]
        rounds = 10
        for expression in expressions:
            t = time.time()
            for _ in range(rounds):
                self.ldb.search(self.ou,
                                expression=expression,
                                scope=SCOPE_SUBTREE,
                                attrs=['cn'])
            # report the number of searches run, not the last loop index
            print('%d %s took %s' % (rounds, expression,
                                     time.time() - t), file=sys.stderr)

    def _test_indexed_search(self):
        # Time a fixed set of subtree searches under self.ou (going by the
        # method name, expressions that can be answered from an index).
        # Each expression is run `rounds` times and the total is written
        # to stderr.
        expressions = ['(objectclass=group)',
                       '(samaccountname=Administrator)'
                       ]
        rounds = 100
        for expression in expressions:
            t = time.time()
            for _ in range(rounds):
                self.ldb.search(self.ou,
                                expression=expression,
                                scope=SCOPE_SUBTREE,
                                attrs=['cn'])
            # report the number of searches run, not the last loop index
            print('%d runs %s took %s' % (rounds, expression,
                                          time.time() - t), file=sys.stderr)

    def _test_add_many_users(self, n=BATCH_SIZE):
        # Add the next n users, numbering on from state.next_user_id.
        # The counter is only advanced once every add has succeeded.
        first = self.state.next_user_id
        limit = first + n
        self._add_users(first, limit)
        self.state.next_user_id = limit

    # The same helper is bound under several test names; the numeric
    # prefixes fix the running order (unittest's default loader sorts
    # test names). Phase 00: join an empty DC, add three batches of
    # users, then join and search again.
    test_00_00_join_empty_dc = _test_join

    # each alias adds one more batch of BATCH_SIZE users (presumably
    # 1000, going by the names)
    test_00_01_adding_users_1000 = _test_add_many_users
    test_00_02_adding_users_2000 = _test_add_many_users
    test_00_03_adding_users_3000 = _test_add_many_users

    test_00_10_join_unlinked_dc = _test_join
    test_00_11_unindexed_search_3k_users = _test_unindexed_search
    test_00_12_indexed_search_3k_users = _test_indexed_search

    def _link_user_and_group(self, u, g):
        """Add user number u to the "member" attribute of group number g."""
        group_dn = "CN=g%d,%s" % (g, self.ou_groups)
        user_dn = "cn=u%d,%s" % (u, self.ou_users)
        msg = Message()
        msg.dn = Dn(self.ldb, group_dn)
        msg["member"] = MessageElement(user_dn, FLAG_MOD_ADD, "member")
        self.ldb.modify(msg)

    def _unlink_user_and_group(self, u, g):
        """Remove user number u from the "member" attribute of group g.

        This is the inverse of _link_user_and_group() and builds the
        user and group DNs in the same way.
        """
        # The user DN must be formatted exactly as in
        # _link_user_and_group(); a format string without conversion
        # specifiers would raise TypeError here.
        user = "cn=u%d,%s" % (u, self.ou_users)
        group = "CN=g%d,%s" % (g, self.ou_groups)
        m = Message()
        m.dn = Dn(self.ldb, group)
        m["member"] = MessageElement(user, FLAG_MOD_DELETE, "member")
        self.ldb.modify(m)

    def _test_link_many_users(self, n=BATCH_SIZE):
        # Link the next n users (counting from state.next_linked_user),
        # putting user i into group i % N_GROUPS. The counter is only
        # advanced after all the links have been made.
        self._prepare_n_groups(N_GROUPS)
        first = self.state.next_linked_user
        limit = first + n
        for uid in range(first, limit):
            self._link_user_and_group(uid, uid % N_GROUPS)
        self.state.next_linked_user = limit

    # Phase 01: each alias links one more batch of users to their
    # i % N_GROUPS group.
    test_01_01_link_users_1000 = _test_link_many_users
    test_01_02_link_users_2000 = _test_link_many_users
    test_01_03_link_users_3000 = _test_link_many_users

    def _test_link_many_users_offset_1(self, n=BATCH_SIZE):
        # Give the next n users (counting from state.next_relinked_user)
        # a second membership, in group (i + 1) % N_GROUPS. The counter
        # is only advanced after all the links have been made.
        first = self.state.next_relinked_user
        limit = first + n
        for uid in range(first, limit):
            self._link_user_and_group(uid, (uid + 1) % N_GROUPS)
        self.state.next_relinked_user = limit

    # Phase 02: add a second link per user (group offset by one), then
    # repeat the join and search timings on the partially linked data.
    test_02_01_link_users_again_1000 = _test_link_many_users_offset_1
    test_02_02_link_users_again_2000 = _test_link_many_users_offset_1
    test_02_03_link_users_again_3000 = _test_link_many_users_offset_1

    test_02_10_join_partially_linked_dc = _test_join
    test_02_11_unindexed_search_partially_linked_dc = _test_unindexed_search
    test_02_12_indexed_search_partially_linked_dc = _test_indexed_search

    def _test_link_many_users_3_groups(self, n=BATCH_SIZE, groups=3):
        # Link the next n users (counting from state.next_linked_user_3)
        # into one of only `groups` groups, group (i + 2) % groups.
        # Unlike the other link helpers, the counter is advanced before
        # any linking is attempted.
        first = self.state.next_linked_user_3
        limit = first + n
        self.state.next_linked_user_3 = limit
        for uid in range(first, limit):
            target = (uid + 2) % groups
            # these are the groups that _test_link_many_users and
            # _test_link_many_users_offset_1 use for this user
            taken = (uid % N_GROUPS, (uid + 1) % N_GROUPS)
            if target in taken:
                continue
            self._link_user_and_group(uid, target)

    # Phase 03: a third pass that concentrates users in a few groups.
    test_03_01_link_users_again_1000_few_groups = _test_link_many_users_3_groups
    test_03_02_link_users_again_2000_few_groups = _test_link_many_users_3_groups
    test_03_03_link_users_again_3000_few_groups = _test_link_many_users_3_groups

    def _test_remove_links_0(self, n=BATCH_SIZE):
        # Remove the i % N_GROUPS membership from the next n users
        # (counting from state.next_removed_link_0). The counter is
        # advanced before any unlinking is attempted.
        first = self.state.next_removed_link_0
        limit = first + n
        self.state.next_removed_link_0 = limit
        for uid in range(first, limit):
            self._unlink_user_and_group(uid, uid % N_GROUPS)

    # Phase 04: remove the original (i % N_GROUPS) links again, one
    # batch of users per alias.
    test_04_01_remove_some_links_1000 = _test_remove_links_0
    test_04_02_remove_some_links_2000 = _test_remove_links_0
    test_04_03_remove_some_links_3000 = _test_remove_links_0

    # Phase 05: back to using _test_add_many_users, for one more batch
    # of users now that links exist.
    test_05_01_adding_users_after_links_4000 = _test_add_many_users

    # reset the link count, to replace the original links
    def test_06_01_relink_users_1000(self):
        # Rewind the shared counter to user 0 so that this call, and the
        # test_06_* aliases of _test_link_many_users that follow, start
        # again from the first user.
        self.state.next_linked_user = 0
        self._test_link_many_users()

    # Phase 06 continued: keep re-linking batch by batch, now including
    # the users added in phase 05, then extend the offset-by-one and
    # few-groups passes by one batch each.
    test_06_02_link_users_2000 = _test_link_many_users
    test_06_03_link_users_3000 = _test_link_many_users
    test_06_04_link_users_4000 = _test_link_many_users
    test_06_05_link_users_again_4000 = _test_link_many_users_offset_1
    test_06_06_link_users_again_4000_few_groups = _test_link_many_users_3_groups

    # Phase 07: one more batch of users.
    test_07_01_adding_users_after_links_5000 = _test_add_many_users

    def _test_link_random_users_and_groups(self, n=BATCH_SIZE, groups=100):
        # Make n attempts at linking a randomly chosen existing user to
        # a randomly chosen group out of `groups`. An attempt that fails
        # with LdbError is skipped, not retried.
        self._prepare_n_groups(groups)
        for _ in range(n):
            uid = random.randrange(self.state.next_user_id)
            gid = random.randrange(groups)
            try:
                self._link_user_and_group(uid, gid)
            except LdbError:
                # presumably this pair is already linked; best effort only
                continue

    # Phase 08: two rounds of random linking across 100 groups.
    test_08_01_link_random_users_100_groups = _test_link_random_users_and_groups
    test_08_02_link_random_users_100_groups = _test_link_random_users_and_groups

    # Phases 10-11: repeat the search and join timings on the full
    # database.
    test_10_01_unindexed_search_full_dc = _test_unindexed_search
    test_10_02_indexed_search_full_dc = _test_indexed_search
    test_11_02_join_full_dc = _test_join

    def test_20_01_delete_50_groups(self):
        # Delete the 50 highest-numbered groups, then lower the recorded
        # group count to match.
        total = self.state.n_groups
        for gid in range(total - 50, total):
            group_dn = "cn=g%d,%s" % (gid, self.ou_groups)
            self.ldb.delete(group_dn)
        self.state.n_groups -= 50

    def _test_delete_many_users(self, n=BATCH_SIZE):
        # Delete the n most recently added users (fewer if fewer exist).
        # state.next_user_id is lowered before the deletions start.
        upper = self.state.next_user_id
        lower = max(0, upper - n)
        self.state.next_user_id = lower
        for uid in range(lower, upper):
            user_dn = "cn=u%d,%s" % (uid, self.ou_users)
            self.ldb.delete(user_dn)

    # Phase 21: delete users a batch at a time, newest first.
    test_21_01_delete_users_5000_lightly_linked = _test_delete_many_users
    test_21_02_delete_users_4000_lightly_linked = _test_delete_many_users
    test_21_03_delete_users_3000 = _test_delete_many_users

    def test_22_01_delete_all_groups(self):
        # Delete every group still recorded in state.n_groups, then reset
        # the count to zero.
        remaining = self.state.n_groups
        for gid in range(remaining):
            group_dn = "cn=g%d,%s" % (gid, self.ou_groups)
            self.ldb.delete(group_dn)
        self.state.n_groups = 0

    # Phase 23: delete the remaining users once the groups are gone.
    # NOTE(review): if tests run in name order, 23_00 ("..._1000") runs
    # before 23_01 ("..._2000"), so the numeric suffixes look swapped
    # relative to what each run deletes -- confirm before relying on
    # the labels.
    test_23_01_delete_users_after_groups_2000 = _test_delete_many_users
    test_23_00_delete_users_after_groups_1000 = _test_delete_many_users

    # Phase 24: a final join against the cleaned-up database.
    test_24_02_join_after_cleanup = _test_join
Exemple #21
0
class UserTests(samba.tests.TestCase):
    def add_if_possible(self, *args, **kwargs):
        """Call self.ldb.add() with the given arguments, ignoring LdbError.

        These tests sometimes leave objects in the database on purpose,
        so failing to add something a second time is not a problem.
        """
        try:
            self.ldb.add(*args, **kwargs)
        except LdbError:
            # presumably the object is already there; carry on regardless
            return

    def setUp(self):
        """Open a SamDB connection, derive the OU names, reseed random."""
        super(UserTests, self).setUp()
        # GlobalState is used as the class itself, not an instance, so
        # the counters stored on it are shared by every test method.
        self.state = GlobalState
        self.lp = lp
        session = system_session(lp)
        self.ldb = SamDB(host, credentials=creds, session_info=session,
                         lp=lp)
        self.base_dn = self.ldb.domain_dn()
        # everything lives under an OU named after our pid; the OUs
        # themselves are created by test_00_03_add_ous_and_groups
        self.ou = "OU=pid%s,%s" % (os.getpid(), self.base_dn)
        self.ou_users = "OU=users,%s" % self.ou
        self.ou_groups = "OU=groups,%s" % self.ou
        self.ou_computers = "OU=computers,%s" % self.ou

        # Seed the RNG from a per-test counter so that each test's
        # random choices depend only on how many tests ran before it.
        self.state.test_number += 1
        random.seed(self.state.test_number)

    def tearDown(self):
        """Defer to the base class; this class adds no cleanup of its own."""
        super(UserTests, self).tearDown()

    def test_00_00_do_nothing(self):
        # Deliberately empty: the time this test takes is the fixed
        # per-test overhead, which gives us an idea of the baseline for
        # all the other tests.
        pass

    def test_00_01_do_nothing_relevant(self):
        # A busy loop that makes no use of self.ldb: it gives a timing
        # reference that depends only on the machine running the tests.
        # The loop is intentionally left naive -- rewriting it (e.g. as
        # sum(range(...))) would change what is being measured.
        # takes around 1 second on i7-4770
        j = 0
        for i in range(30000000):
            j += i

    def test_00_02_do_nothing_sleepily(self):
        # Sleeps for one second and does nothing else, so the reported
        # duration can be compared against a known wall-clock time.
        time.sleep(1)

    def test_00_03_add_ous_and_groups(self):
        # Initialise the database: create the per-process OU tree
        # (parent first), then N_GROUPS empty groups, and record the
        # group count in the shared state.
        containers = (self.ou, self.ou_users, self.ou_groups,
                      self.ou_computers)
        for container_dn in containers:
            self.ldb.add({"dn": container_dn,
                          "objectclass": "organizationalUnit"})

        for gid in range(N_GROUPS):
            group_dn = "cn=g%d,%s" % (gid, self.ou_groups)
            self.ldb.add({"dn": group_dn, "objectclass": "group"})

        self.state.n_groups = N_GROUPS

    def _add_users(self, start, end):
        """Add user objects u<start> .. u<end - 1>, one add() call each.

        Errors raised by self.ldb.add() are not caught here.
        """
        for uid in range(start, end):
            record = {
                "dn": "cn=u%d,%s" % (uid, self.ou_users),
                "objectclass": "user",
            }
            self.ldb.add(record)

    def _add_users_ldif(self, start, end):
        """Add user objects u<start> .. u<end - 1> via one add_ldif() call.

        All the records are concatenated into a single LDIF string, each
        record being followed by an empty line.
        """
        ldif = []
        for uid in range(start, end):
            ldif.extend(("dn: cn=u%d,%s" % (uid, self.ou_users),
                         "objectclass: user",
                         ""))
        self.ldb.add_ldif('\n'.join(ldif))

    def _test_join(self):
        # Run "samba-tool domain join <realm> dc" against the server under
        # test, joining into a scratch target directory which is removed
        # afterwards.
        tmpdir = tempfile.mkdtemp()
        try:
            # Pass only the part after any URL scheme ("xxx://") as
            # --server.
            if '://' in host:
                server = host.split('://', 1)[1]
            else:
                server = host
            cmd = cmd_sambatool.subcommands['domain'].subcommands['join']
            cmd._run(
                "samba-tool domain join", creds.get_realm(), "dc",
                "-U%s%%%s" % (creds.get_username(), creds.get_password()),
                '--targetdir=%s' % tmpdir, '--server=%s' % server)
        finally:
            # Clean up even when the join raises, so failed runs do not
            # leave scratch directories behind.
            shutil.rmtree(tmpdir)

    def _test_unindexed_search(self):
        # Time a fixed set of subtree searches under self.ou (going by the
        # method name, expressions that cannot be answered from an index).
        # Each expression is run `rounds` times and the total is written
        # to stderr.
        expressions = [
            ('(&(objectclass=user)(description='
             'Built-in account for adminstering the computer/domain))'),
            '(description=Built-in account for adminstering the computer/domain)',
            '(objectCategory=*)',
            '(samaccountname=Administrator*)',
        ]
        rounds = 25
        for expression in expressions:
            t = time.time()
            for _ in range(rounds):
                self.ldb.search(self.ou,
                                expression=expression,
                                scope=SCOPE_SUBTREE,
                                attrs=['cn'])
            # report the number of searches run, not the last loop index
            print('%d %s took %s' % (rounds, expression, time.time() - t),
                  file=sys.stderr)

    def _test_indexed_search(self):
        # Time 4000 subtree searches under self.ou for each expression
        # (going by the method name, ones that can be answered from an
        # index). The loop, timing and stderr report are exactly what
        # search_expression_list() does with its defaults, so use it
        # and avoid carrying a second copy of that code.
        expressions = ['(objectclass=group)', '(samaccountname=Administrator)']
        self.search_expression_list(expressions, 4000)

    def _test_base_search(self):
        # Do 4000 base-scope searches on each of the base DN and our four
        # OUs. A "no such object" error is tolerated; any other LdbError
        # is re-raised.
        targets = [
            self.base_dn, self.ou, self.ou_users, self.ou_groups,
            self.ou_computers
        ]
        for dn in targets:
            for _ in range(4000):
                try:
                    self.ldb.search(dn, scope=SCOPE_BASE, attrs=['cn'])
                except LdbError as err:
                    # LdbError carries (error number, message) in .args
                    num, _msg = err.args
                    if num == ERR_NO_SUCH_OBJECT:
                        continue
                    raise

    def _test_base_search_failing(self):
        # Do 4000 base-scope searches for DNs that are expected not to
        # exist ("missing<i>" prepended to our OU's DN). A "no such
        # object" error is the expected outcome; any other LdbError is
        # re-raised.
        pattern = 'missing%d' + self.ou
        for i in range(4000):
            try:
                self.ldb.search(pattern % i, scope=SCOPE_BASE, attrs=['cn'])
            except LdbError as e:
                # Unpack e.args, as _test_base_search does: an exception
                # object itself cannot be unpacked on Python 3.
                (num, msg) = e.args
                if num != ERR_NO_SUCH_OBJECT:
                    raise

    def search_expression_list(self,
                               expressions,
                               rounds,
                               attrs=None,
                               scope=SCOPE_SUBTREE):
        """Run and time each search expression under self.ou.

        Every expression is searched for `rounds` times; the elapsed
        time for those searches is written to stderr, one line per
        expression.

        :param expressions: iterable of LDAP filter strings
        :param rounds: how many times to repeat each search
        :param attrs: attributes to request; defaults to ['cn']
        :param scope: search scope; defaults to SCOPE_SUBTREE
        """
        if attrs is None:
            # avoid a shared mutable default argument
            attrs = ['cn']
        for expression in expressions:
            t = time.time()
            for _ in range(rounds):
                # honour the caller's attrs and scope, which used to be
                # accepted but ignored
                self.ldb.search(self.ou,
                                expression=expression,
                                scope=scope,
                                attrs=attrs)
            # Report the number of rounds, not the last loop index: the
            # index is one too small, and is undefined when rounds is 0.
            print('%d runs %s took %s' % (rounds, expression,
                                          time.time() - t),
                  file=sys.stderr)

    def _test_complex_search(self, n=100):
        # Build n random two-clause filters of the form
        #   (<joiner>(<clause 1>)(<clause 2>))
        # where each clause is attr<op>value, optionally negated, and
        # search for each of them once.
        classes = ['samaccountname', 'objectCategory', 'dn', 'member']
        values = ['*', '*t*', 'g*', 'user']
        comparators = ['=', '<=', '>=']  # '~=' causes error
        maybe_not = ['!(', '']
        joiners = ['&', '|']

        # There are 2 * 4*4 * 4*4 * 3*3 * 2*2 = 18432 combinations, which
        # is not huge but would take hours to search, so we only take a
        # sample of n.
        all_permutations = list(
            itertools.product(joiners, classes, classes, values, values,
                              comparators, comparators, maybe_not, maybe_not))

        def clause(negation, attr, op, value):
            # a negated clause has one more bracket to close
            closing = '))' if negation else ')'
            return '(' + negation + attr + op + value + closing

        expressions = []
        for combo in random.sample(all_permutations, n):
            j, c1, c2, v1, v2, o1, o2, n1, n2 = combo
            expressions.append('(' + j +
                               clause(n1, c1, o1, v1) +
                               clause(n2, c2, o2, v2) +
                               ')')

        self.search_expression_list(expressions, 1)

    def _test_member_search(self, rounds=10):
        # Time searches on the "member" attribute: for 20 user numbers,
        # one search with a full user DN (users 500-519) and one with a
        # wildcard prefix (u700* - u719*), each repeated `rounds` times.
        expressions = []
        for offset in range(20):
            by_dn = '(member=cn=u%d,%s)' % (offset + 500, self.ou_users)
            by_prefix = '(member=u%d*)' % (offset + 700, )
            expressions.extend([by_dn, by_prefix])

        self.search_expression_list(expressions, rounds)

    def _test_memberof_search(self, rounds=200):
        # Time searches on the "memberOf" attribute. Note that `rounds`
        # here only caps how many groups are used to build expressions;
        # every expression is always searched for twice.
        expressions = []
        n_groups_used = min(self.state.n_groups, rounds)
        for gid in range(n_groups_used):
            expressions.extend([
                '(memberOf=cn=g%d,%s)' % (gid, self.ou_groups),
                '(memberOf=cn=g%d*)' % (gid, ),
                # the same for every group, so it is repeated in the list
                '(memberOf=cn=*%s*)' % self.ou_groups,
            ])

        self.search_expression_list(expressions, 2)

    def _test_add_many_users(self, n=BATCH_SIZE):
        # Add the next n users one add() at a time, numbering on from
        # state.next_user_id. The counter is only advanced once every
        # add has succeeded.
        first = self.state.next_user_id
        limit = first + n
        self._add_users(first, limit)
        self.state.next_user_id = limit

    def _test_add_many_users_ldif(self, n=BATCH_SIZE):
        # Add the next n users with a single LDIF add, numbering on from
        # state.next_user_id. The counter is only advanced once the add
        # has succeeded.
        first = self.state.next_user_id
        limit = first + n
        self._add_users_ldif(first, limit)
        self.state.next_user_id = limit

    def _link_user_and_group(self, u, g):
        """Add user number u to group number g unless already recorded.

        Returns False, without touching the database, when (u, g) is
        already in state.active_links; otherwise performs the modify,
        records the link and returns True.
        """
        pair = (u, g)
        known = self.state.active_links
        if pair in known:
            return False

        group_dn = "CN=g%d,%s" % (g, self.ou_groups)
        user_dn = "cn=u%d,%s" % (u, self.ou_users)
        msg = Message()
        msg.dn = Dn(self.ldb, group_dn)
        msg["member"] = MessageElement(user_dn, FLAG_MOD_ADD, "member")
        self.ldb.modify(msg)
        # only recorded once the modify has gone through
        known.add(pair)
        return True

    def _unlink_user_and_group(self, u, g):
        """Remove user number u from group number g if the link is recorded.

        Returns False, without touching the database, when (u, g) is not
        in state.active_links; otherwise performs the modify, forgets
        the link and returns True.
        """
        link = (u, g)
        if link not in self.state.active_links:
            return False

        # The user DN must be formatted exactly as in
        # _link_user_and_group(); a format string without conversion
        # specifiers would raise TypeError here.
        user = "cn=u%d,%s" % (u, self.ou_users)
        group = "CN=g%d,%s" % (g, self.ou_groups)
        m = Message()
        m.dn = Dn(self.ldb, group)
        m["member"] = MessageElement(user, FLAG_MOD_DELETE, "member")
        self.ldb.modify(m)
        self.state.active_links.remove(link)
        return True

    def _test_link_many_users(self, n=LINK_BATCH_SIZE):
        # Create n new user/group links, one modify per link. Picking
        # the group as randrange(randrange(ng) + 1) links unevenly,
        # putting more users in the first group and fewer in the last.
        ng = self.state.n_groups
        nu = self.state.next_user_id
        remaining = n
        while remaining:
            uid = random.randrange(nu)
            gid = random.randrange(random.randrange(ng) + 1)
            # pairs that are already linked do not count towards n
            if self._link_user_and_group(uid, gid):
                remaining -= 1

    def _test_link_many_users_batch(self, n=(LINK_BATCH_SIZE * 10)):
        # Create n new user/group links, but batch them so that there is
        # only one modify per group. As in _test_link_many_users, this
        # links unevenly, putting more users in the first group and
        # fewer in the last.
        ng = self.state.n_groups
        nu = self.state.next_user_id

        # one modify message per group, initially with nothing to change
        per_group = []
        for gid in range(ng):
            msg = Message()
            msg.dn = Dn(self.ldb, "CN=g%d,%s" % (gid, self.ou_groups))
            per_group.append(msg)

        remaining = n
        while remaining:
            uid = random.randrange(nu)
            gid = random.randrange(random.randrange(ng) + 1)
            pair = (uid, gid)
            if pair in self.state.active_links:
                continue
            # Each addition is stored under its own key ("member<uid>")
            # while the element itself is named "member" -- presumably
            # so that one message can carry many member additions.
            per_group[gid]["member%s" % uid] = MessageElement(
                "cn=u%d,%s" % (uid, self.ou_users), FLAG_MOD_ADD, "member")
            # NOTE: the link is recorded before the modify is sent
            self.state.active_links.add(pair)
            remaining -= 1

        for msg in per_group:
            try:
                self.ldb.modify(msg)
            except LdbError as err:
                # a failed batch is reported but does not fail the test
                print(err)
                print(msg)

    def _test_remove_some_links(self, n=(LINK_BATCH_SIZE // 2)):
        # Remove n randomly chosen links out of those currently recorded
        # in state.active_links.
        for uid, gid in random.sample(list(self.state.active_links), n):
            self._unlink_user_and_group(uid, gid)

    # The same helpers are bound under several test names; the numeric
    # prefixes fix the running order (unittest's default loader sorts
    # test names), and the suffixes describe the expected database
    # size at that point.
    test_00_11_join_empty_dc = _test_join

    test_00_12_adding_users_2000 = _test_add_many_users

    # join and search timings with users but no links
    test_00_20_join_unlinked_2k_users = _test_join
    test_00_21_unindexed_search_2k_users = _test_unindexed_search
    test_00_22_indexed_search_2k_users = _test_indexed_search

    test_00_23_complex_search_2k_users = _test_complex_search
    test_00_24_member_search_2k_users = _test_member_search
    test_00_25_memberof_search_2k_users = _test_memberof_search

    test_00_27_base_search_2k_users = _test_base_search
    test_00_28_base_search_failing_2k_users = _test_base_search_failing

    # link the users, one modify per link and then batched per group
    test_01_01_link_2k_users = _test_link_many_users
    test_01_02_link_2k_users_batch = _test_link_many_users_batch

    # the same timings again, now with links in place
    test_02_10_join_2k_linked_dc = _test_join
    test_02_11_unindexed_search_2k_linked_dc = _test_unindexed_search
    test_02_12_indexed_search_2k_linked_dc = _test_indexed_search

    test_04_01_remove_some_links_2k = _test_remove_some_links

    # grow the database: more users (via LDIF), more links, more users
    test_05_01_adding_users_after_links_4k_ldif = _test_add_many_users_ldif

    test_06_04_link_users_4k = _test_link_many_users
    test_06_05_link_users_4k_batch = _test_link_many_users_batch

    test_07_01_adding_users_after_links_6k = _test_add_many_users

    def _test_ldif_well_linked_group(self, link_chance=1.0):
        # Add one new group with a single LDIF add, its member list
        # already filled in: each existing user is included with
        # probability link_chance. The group count and the recorded
        # links are updated before the add is sent.
        gid = self.state.n_groups
        self.state.n_groups += 1
        ldif = ["dn: CN=g%d,%s" % (gid, self.ou_groups), "objectclass: group"]

        for uid in range(self.state.next_user_id):
            if random.random() > link_chance:
                continue
            ldif.append("member: cn=u%d,%s" % (uid, self.ou_users))
            self.state.active_links.add((uid, gid))

        ldif.append("")
        self.ldb.add_ldif('\n'.join(ldif))

    # with the default link_chance of 1.0 every existing user is a member
    test_09_01_add_fully_linked_group = _test_ldif_well_linked_group

    def test_09_02_add_exponentially_diminishing_linked_groups(self):
        # Add a series of pre-linked groups, the first containing each
        # user with probability 0.8 and every later one with 3/4 of the
        # previous probability, stopping once that is 0.01 or less.
        chance = 0.8
        while chance > 0.01:
            self._test_ldif_well_linked_group(chance)
            chance *= 0.75

    # one more batch of random links, now that more groups exist
    test_09_04_link_users_6k = _test_link_many_users

    # Phase 10: the search timings again on the larger database. The
    # complex, member and memberOf searches are defined as methods
    # further down because they pass non-default arguments.
    test_10_01_unindexed_search_6k_users = _test_unindexed_search
    test_10_02_indexed_search_6k_users = _test_indexed_search

    test_10_27_base_search_6k_users = _test_base_search
    test_10_28_base_search_failing_6k_users = _test_base_search_failing

    def test_10_03_complex_search_6k_users(self):
        # a smaller sample (50, not the default 100) of random filters
        self._test_complex_search(n=50)

    def test_10_04_member_search_6k_users(self):
        # each member search is run once, not the default 10 times
        self._test_member_search(rounds=1)

    def test_10_05_memberof_search_6k_users(self):
        # rounds=5 limits the memberOf expressions to at most 5 groups
        self._test_memberof_search(rounds=5)

    # Phases 11-12: join against the full database, then start taking
    # it apart again by removing a random sample of links.
    test_11_02_join_full_dc = _test_join

    test_12_01_remove_some_links_6k = _test_remove_some_links

    def _test_delete_many_users(self, n=DELETE_BATCH_SIZE):
        # Delete the n most recently added users (fewer if fewer exist)
        # and forget every recorded link that involved one of them.
        # state.next_user_id is lowered before the deletions start.
        e = self.state.next_user_id
        s = max(0, e - n)
        self.state.next_user_id = s
        for i in range(s, e):
            self.ldb.delete("cn=u%d,%s" % (i, self.ou_users))

        # The deleted users are exactly those numbered s <= u < e. (The
        # previous test, "s >= u > e", could never be true, so stale
        # links were left behind in active_links.)
        stale = [link for link in self.state.active_links
                 if s <= link[0] < e]
        self.state.active_links.difference_update(stale)

    # Phase 20: delete the newest DELETE_BATCH_SIZE users.
    test_20_01_delete_users_6k = _test_delete_many_users

    def test_21_01_delete_10_groups(self):
        # Delete the 10 highest-numbered groups, lower the recorded
        # group count, and forget every recorded link to a group that no
        # longer exists.
        old_count = self.state.n_groups
        for gid in range(old_count - 10, old_count):
            group_dn = "cn=g%d,%s" % (gid, self.ou_groups)
            self.ldb.delete(group_dn)
        self.state.n_groups -= 10

        new_count = self.state.n_groups
        stale = [link for link in self.state.active_links
                 if link[1] >= new_count]
        for link in stale:
            self.state.active_links.remove(link)

    # another batch of DELETE_BATCH_SIZE users, now with fewer groups
    test_21_02_delete_users_5950 = _test_delete_many_users

    def test_22_01_delete_all_groups(self):
        # Delete every group still recorded in state.n_groups, then reset
        # both the group count and the set of recorded links.
        remaining = self.state.n_groups
        for gid in range(remaining):
            group_dn = "cn=g%d,%s" % (gid, self.ou_groups)
            self.ldb.delete(group_dn)
        self.state.n_groups = 0
        self.state.active_links = set()

    # XXX assert the state is as we think, using searches

    def test_23_01_delete_users_5900_after_groups(self):
        # Delete four batches' worth of users in one go. We do not
        # delete everything because it takes too long.
        self._test_delete_many_users(n=4 * DELETE_BATCH_SIZE)

    # Phase 24: a final join, with the groups and some of the users gone.
    test_24_02_join_after_partial_cleanup = _test_join
Exemple #22
0
class VLVTests(samba.tests.TestCase):

    def create_user(self, i, n, prefix='vlvtest', suffix='', attrs=None):
        """Create user number i (of n) under self.ou and return its record.

        The attribute values are derived from i and n so that different
        attributes give different sort orders. The record is appended to
        self.users and added to the database.

        :param i: index of this user
        :param n: total number of users being created
        :param prefix: text placed before i in the cn
        :param suffix: text placed after i in the cn
        :param attrs: optional dict of attributes that override or
            extend the generated ones
        :return: the dict that was passed to self.ldb.add()
        """
        name = "%s%d%s" % (prefix, i, suffix)
        user = {
            'cn': name,
            "objectclass": "user",
            'givenName': "abcdefghijklmnopqrstuvwxyz"[i % 26],
            "roomNumber": "%sbc" % (n - i),
            "carLicense": "后来经",
            "employeeNumber": "%s%sx" % (abs(i * (99 - i)), '\n' * (i & 255)),
            "accountExpires": "%s" % (10 ** 9 + 1000000 * i),
            "msTSExpireDate4": "19%02d0101010000.0Z" % (i % 100),
            "flags": str(i * (n - i)),
            "serialNumber": "abc %s%s%s" % ('AaBb |-/'[i & 7],
                                            ' 3z}'[i & 3],
                                            '"@'[i & 1],),
        }

        # _user_broken_attrs tests are broken due to problems outside
        # of VLV. The dict is not used; it is kept as a record of the
        # attribute values that trigger those problems.
        _user_broken_attrs = {
            # Sort doesn't look past a NUL byte.
            "photo": "\x00%d" % (n - i),
            "audio": "%sn octet string %s%s ♫♬\x00lalala" % ('Aa'[i & 1],
                                                             chr(i & 255), i),
            "displayNamePrintable": "%d\x00%c" % (i, i & 255),
            "adminDisplayName": "%d\x00b" % (n-i),
            "title": "%d%sb" % (n - i, '\x00' * i),
            "comment": "Favourite colour is %d" % (n % (i + 1)),

            # Names that vary only in case. Windows returns
            # equivalent addresses in the order they were put
            # in ('a st', 'A st',...).
            "street": "%s st" % (chr(65 | (i & 14) | ((i & 1) * 32))),
        }

        if attrs is not None:
            user.update(attrs)

        user['dn'] = "cn=%s,%s" % (user['cn'], self.ou)

        if opts.skip_attr_regex:
            match = re.compile(opts.skip_attr_regex).search
            # Iterate over a snapshot of the keys: deleting from a dict
            # while iterating over it raises RuntimeError on Python 3.
            for k in list(user):
                if match(k):
                    del user[k]

        self.users.append(user)
        self.ldb.add(user)
        return user

    def setUp(self):
        """Connect, (re)create the test OU, and fill it with users.

        N_ELEMENTS users are created with create_user(), after which the
        generated attribute names are sorted into the key lists that are
        stored on the instance.
        """
        super(VLVTests, self).setUp()
        self.ldb = SamDB(host, credentials=creds,
                         session_info=system_session(lp), lp=lp)

        self.base_dn = self.ldb.domain_dn()
        self.ou = "ou=vlv,%s" % self.base_dn
        if opts.delete_in_setup:
            # Remove whatever an earlier run left behind. Failure (for
            # instance because the OU is not there) is only reported.
            try:
                self.ldb.delete(self.ou, ['tree_delete:1'])
            except ldb.LdbError as e:
                print("tried deleting %s, got error %s" % (self.ou, e))
        self.ldb.add({
            "dn": self.ou,
            "objectclass": "organizationalUnit"})

        self.users = []
        for i in range(N_ELEMENTS):
            self.create_user(i, N_ELEMENTS)

        # every user record has the same keys, so the first will do
        attrs = self.users[0].keys()
        self.binary_sorted_keys = ['audio',
                                   'photo',
                                   "msTSExpireDate4",
                                   'serialNumber',
                                   "displayNamePrintable"]

        self.numeric_sorted_keys = ['flags',
                                    'accountExpires']

        self.timestamp_keys = ['msTSExpireDate4']

        self.int64_keys = set(['accountExpires'])

        # whatever is in neither of the two lists above
        self.locale_sorted_keys = [x for x in attrs if
                                   x not in (self.binary_sorted_keys +
                                             self.numeric_sorted_keys)]

        # don't try spaces, etc in cn
        self.delicate_keys = ['cn']
Exemple #23
0
class DsdbTests(TestCase):

    def setUp(self):
        """Open the SamDB and create a throw-away user for the tests.

        The DN of the new account is stored in ``self.account_dn`` and the
        account is registered for forced deletion when the test ends.
        """
        super(DsdbTests, self).setUp()

        # Connect using credentials guessed from the test environment.
        self.lp = samba.tests.env_loadparm()
        self.creds = Credentials()
        self.creds.guess(self.lp)
        self.session = system_session()
        self.samdb = SamDB(session_info=self.session,
                           credentials=self.creds,
                           lp=self.lp)

        # A short random hex suffix distinguishes the account name.
        account_name = "dsdb-user-%s" % uuid.uuid4().hex[0:6]
        account_password = samba.generate_random_password(32, 32)

        self.account_dn = "cn=%s,cn=Users,%s" % (account_name,
                                                 self.samdb.domain_dn())
        self.samdb.newuser(username=account_name,
                           password=account_password,
                           description="Test user for dsdb test")

        # Remove the account again once the test has finished.
        self.addCleanup(delete_force, self.samdb, self.account_dn)

    def test_get_oid_from_attrid(self):
        """Attribute id 591614 must map to OID 1.2.840.113556.1.4.1790."""
        oid = self.samdb.get_oid_from_attid(591614)
        # assertEqual instead of the deprecated assertEquals alias.
        self.assertEqual(oid, "1.2.840.113556.1.4.1790")

    def test_error_replpropertymetadata(self):
        """Replacing replPropertyMetaData with only a bumped version fails.

        The version of the attid 13 (description) entry is incremented, the
        blob is repacked and written back with control
        ``local_oid:1.3.6.1.4.1.7165.4.3.14:0``; the modify must raise
        ``ldb.LdbError``.
        """
        res = self.samdb.search(scope=ldb.SCOPE_SUBTREE,
                                base=self.account_dn,
                                attrs=["replPropertyMetaData"])
        repl = ndr_unpack(drsblobs.replPropertyMetaDataBlob,
                          res[0]["replPropertyMetaData"][0])
        for o in repl.ctr.array:
            # Search for Description
            if o.attid == 13:
                o.version = o.version + 1
        replBlob = ndr_pack(repl)
        msg = ldb.Message()
        msg.dn = res[0].dn
        msg["replPropertyMetaData"] = ldb.MessageElement(
            replBlob, ldb.FLAG_MOD_REPLACE, "replPropertyMetaData")
        self.assertRaises(ldb.LdbError, self.samdb.modify, msg,
                          ["local_oid:1.3.6.1.4.1.7165.4.3.14:0"])

    def test_error_replpropertymetadata_nochange(self):
        """Writing back an unchanged replPropertyMetaData blob must fail.

        The blob is unpacked and repacked without modification, then sent
        as a replace with control ``local_oid:1.3.6.1.4.1.7165.4.3.14:0``;
        ``ldb.LdbError`` is expected.
        """
        found = self.samdb.search(scope=ldb.SCOPE_SUBTREE,
                                  base=self.account_dn,
                                  attrs=["replPropertyMetaData"])
        entry = found[0]
        unpacked = ndr_unpack(drsblobs.replPropertyMetaDataBlob,
                              entry["replPropertyMetaData"][0])
        repacked = ndr_pack(unpacked)

        update = ldb.Message()
        update.dn = entry.dn
        update["replPropertyMetaData"] = ldb.MessageElement(
            repacked, ldb.FLAG_MOD_REPLACE, "replPropertyMetaData")

        controls = ["local_oid:1.3.6.1.4.1.7165.4.3.14:0"]
        self.assertRaises(ldb.LdbError, self.samdb.modify, update, controls)

    def test_error_replpropertymetadata_allow_sort(self):
        """An unchanged blob is accepted when both controls are supplied.

        Same round trip as the "nochange" case, but the modify carries
        ``local_oid:1.3.6.1.4.1.7165.4.3.25:0`` in addition to
        ``local_oid:1.3.6.1.4.1.7165.4.3.14:0`` and is expected not to raise.
        """
        found = self.samdb.search(scope=ldb.SCOPE_SUBTREE,
                                  base=self.account_dn,
                                  attrs=["replPropertyMetaData"])
        entry = found[0]
        unpacked = ndr_unpack(drsblobs.replPropertyMetaDataBlob,
                              entry["replPropertyMetaData"][0])
        repacked = ndr_pack(unpacked)

        update = ldb.Message()
        update.dn = entry.dn
        update["replPropertyMetaData"] = ldb.MessageElement(
            repacked, ldb.FLAG_MOD_REPLACE, "replPropertyMetaData")

        controls = ["local_oid:1.3.6.1.4.1.7165.4.3.14:0",
                    "local_oid:1.3.6.1.4.1.7165.4.3.25:0"]
        self.samdb.modify(update, controls)

    def test_twoatt_replpropertymetadata(self):
        """Changing replPropertyMetaData together with description fails.

        The attid 13 (description) entry gets a bumped version and a
        local_usn of uSNChanged + 1; the message additionally replaces
        ``description``.  The modify must raise ``ldb.LdbError``.
        """
        res = self.samdb.search(scope=ldb.SCOPE_SUBTREE,
                                base=self.account_dn,
                                attrs=["replPropertyMetaData", "uSNChanged"])
        repl = ndr_unpack(drsblobs.replPropertyMetaDataBlob,
                          res[0]["replPropertyMetaData"][0])
        for o in repl.ctr.array:
            # Search for Description
            if o.attid == 13:
                o.version = o.version + 1
                o.local_usn = int(str(res[0]["uSNChanged"])) + 1
        replBlob = ndr_pack(repl)
        msg = ldb.Message()
        msg.dn = res[0].dn
        msg["replPropertyMetaData"] = ldb.MessageElement(
            replBlob, ldb.FLAG_MOD_REPLACE, "replPropertyMetaData")
        msg["description"] = ldb.MessageElement(
            "new val", ldb.FLAG_MOD_REPLACE, "description")
        self.assertRaises(ldb.LdbError, self.samdb.modify, msg,
                          ["local_oid:1.3.6.1.4.1.7165.4.3.14:0"])

    def test_set_replpropertymetadata(self):
        """A consistent replPropertyMetaData update is accepted.

        For the attid 13 (description) entry the version is incremented and
        both local_usn and originating_usn are set to uSNChanged + 1; the
        replace with control ``local_oid:1.3.6.1.4.1.7165.4.3.14:0`` is
        expected not to raise.
        """
        res = self.samdb.search(scope=ldb.SCOPE_SUBTREE,
                                base=self.account_dn,
                                attrs=["replPropertyMetaData", "uSNChanged"])
        repl = ndr_unpack(drsblobs.replPropertyMetaDataBlob,
                          res[0]["replPropertyMetaData"][0])
        for o in repl.ctr.array:
            # Search for Description
            if o.attid == 13:
                # Computed only when the entry exists, as before.
                next_usn = int(str(res[0]["uSNChanged"])) + 1
                o.version = o.version + 1
                o.local_usn = next_usn
                o.originating_usn = next_usn
        replBlob = ndr_pack(repl)
        msg = ldb.Message()
        msg.dn = res[0].dn
        msg["replPropertyMetaData"] = ldb.MessageElement(
            replBlob, ldb.FLAG_MOD_REPLACE, "replPropertyMetaData")
        self.samdb.modify(msg, ["local_oid:1.3.6.1.4.1.7165.4.3.14:0"])

    def test_ok_get_attribute_from_attid(self):
        """Attribute id 13 must resolve to the name "description"."""
        # assertEqual instead of the deprecated assertEquals alias.
        self.assertEqual(self.samdb.get_attribute_from_attid(13),
                         "description")

    def test_ko_get_attribute_from_attid(self):
        """Looking up attribute id 11979 must yield None."""
        # assertEqual instead of the deprecated assertEquals alias.
        self.assertEqual(self.samdb.get_attribute_from_attid(11979), None)

    def test_get_attribute_replmetadata_version(self):
        """The test user's unicodePwd metadata version is expected to be 2."""
        res = self.samdb.search(scope=ldb.SCOPE_SUBTREE,
                                base=self.account_dn,
                                attrs=["dn"])
        # assertEqual instead of the deprecated assertEquals alias.
        self.assertEqual(len(res), 1)
        dn = str(res[0].dn)
        self.assertEqual(
            self.samdb.get_attribute_replmetadata_version(dn, "unicodePwd"),
            2)

    def test_set_attribute_replmetadata_version(self):
        """Setting the description metadata version must be readable back.

        The current version is read, raised by 2 through
        ``set_attribute_replmetadata_version`` and read again.
        """
        res = self.samdb.search(scope=ldb.SCOPE_SUBTREE,
                                base=self.account_dn,
                                attrs=["dn"])
        # assertEqual instead of the deprecated assertEquals alias.
        self.assertEqual(len(res), 1)
        dn = str(res[0].dn)
        version = self.samdb.get_attribute_replmetadata_version(
            dn, "description")
        self.samdb.set_attribute_replmetadata_version(
            dn, "description", version + 2)
        self.assertEqual(
            self.samdb.get_attribute_replmetadata_version(dn, "description"),
            version + 2)

    def test_no_error_on_invalid_control(self):
        """A search carrying the invalid control with flag ``:0`` must work.

        The control is built from
        ``dsdb.DSDB_CONTROL_INVALID_NOT_IMPLEMENTED``; any ``ldb.LdbError``
        fails the test.
        """
        try:
            # Only the absence of an exception matters; the result is unused.
            self.samdb.search(scope=ldb.SCOPE_SUBTREE,
                              base=self.account_dn,
                              attrs=["replPropertyMetaData"],
                              controls=["local_oid:%s:0"
                                        % dsdb.DSDB_CONTROL_INVALID_NOT_IMPLEMENTED])
        except ldb.LdbError:
            self.fail("Should have not raised an exception")

    def test_error_on_invalid_critical_control(self):
        """An error for the invalid control with flag ``:1`` must be
        ERR_UNSUPPORTED_CRITICAL_EXTENSION.

        Any other ``ldb.LdbError`` code fails the test.
        """
        # NOTE(review): if the search raises nothing at all the test passes
        # silently; confirm whether a missing error should also be a failure.
        try:
            self.samdb.search(scope=ldb.SCOPE_SUBTREE,
                              base=self.account_dn,
                              attrs=["replPropertyMetaData"],
                              controls=["local_oid:%s:1"
                                        % dsdb.DSDB_CONTROL_INVALID_NOT_IMPLEMENTED])
        except ldb.LdbError as e:
            (errno, estr) = e.args
            if errno != ldb.ERR_UNSUPPORTED_CRITICAL_EXTENSION:
                # Use the unpacked error string: exceptions cannot be
                # subscripted (e[1]) on Python 3.
                self.fail("Got %s should have got ERR_UNSUPPORTED_CRITICAL_EXTENSION"
                          % estr)

    # Allocate a unique RID for use in the objectSID tests.
    #
    def allocate_rid(self):
        """Allocate a RID inside its own transaction and return it as str.

        The transaction is cancelled and the error re-raised if the
        allocation fails; otherwise the transaction is committed.
        """
        self.samdb.transaction_start()
        try:
            new_rid = self.samdb.allocate_rid()
        except:
            # Roll back on any failure, then let the caller see it.
            self.samdb.transaction_cancel()
            raise
        else:
            self.samdb.transaction_commit()
        return str(new_rid)

    # Ensure that duplicate objectSIDs are permitted for foreign security
    # principals.
    #
    def test_duplicate_objectSIDs_allowed_on_foreign_security_principals(self):
        """Re-adding a deleted foreignSecurityPrincipal must succeed.

        Without a control the add is rejected: first with
        ERR_OBJECT_CLASS_VIOLATION / WERR_DS_MISSING_REQUIRED_ATT when no
        objectSid is given, then with ERR_UNWILLING_TO_PERFORM /
        WERR_DS_ILLEGAL_MOD_OPERATION when one is supplied.  With the
        ``provision:0`` control the object is added, deleted and added
        again under the same DN; the second add must not raise.
        """

        #
        # We need to build a foreign security principal SID,
        # i.e. a SID not in the current domain.
        #
        dom_sid = self.samdb.get_domain_sid()
        # Swap the last character of the domain SID for a different digit
        # so the resulting SID differs from the local domain's.
        if str(dom_sid).endswith("0"):
            c = "9"
        else:
            c = "0"
        sid_str = str(dom_sid)[:-1] + c + "-1000"
        sid     = ndr_pack(security.dom_sid(sid_str))
        basedn  = self.samdb.get_default_basedn()
        dn      = "CN=%s,CN=ForeignSecurityPrincipals,%s" % (sid_str, basedn)

        #
        # First without control
        #

        # No objectSid supplied: expect an object class violation.
        try:
            self.samdb.add({
                "dn": dn,
                "objectClass": "foreignSecurityPrincipal"})
            self.fail("No exception should get ERR_OBJECT_CLASS_VIOLATION")
        except ldb.LdbError as e:
            (code, msg) = e.args
            self.assertEqual(code, ldb.ERR_OBJECT_CLASS_VIOLATION, str(e))
            # The WERROR code is expected as 8 hex digits in the message.
            werr = "%08X" % werror.WERR_DS_MISSING_REQUIRED_ATT
            self.assertTrue(werr in msg, msg)

        # objectSid supplied: expect the server to refuse the operation.
        try:
            self.samdb.add({
                "dn": dn,
                "objectClass": "foreignSecurityPrincipal",
                "objectSid": sid})
            self.fail("No exception should get ERR_UNWILLING_TO_PERFORM")
        except ldb.LdbError as e:
            (code, msg) = e.args
            self.assertEqual(code, ldb.ERR_UNWILLING_TO_PERFORM, str(e))
            werr = "%08X" % werror.WERR_DS_ILLEGAL_MOD_OPERATION
            self.assertTrue(werr in msg, msg)

        #
        # We need to use the provision control
        # in order to add foreignSecurityPrincipal
        # objects
        #

        controls = ["provision:0"]
        self.samdb.add({
            "dn": dn,
            "objectClass": "foreignSecurityPrincipal"},
            controls=controls)

        self.samdb.delete(dn)

        # Adding the same DN again after the delete must be accepted.
        try:
            self.samdb.add({
                "dn": dn,
                "objectClass": "foreignSecurityPrincipal"},
                controls=controls)
        except ldb.LdbError as e:
            (code, msg) = e.args
            self.fail("Got unexpected exception %d - %s "
                      % (code, msg))

        # cleanup
        self.samdb.delete(dn)

    def _test_foreignSecurityPrincipal(self, obj_class, fpo_attr):
        """Add SID references to ``fpo_attr`` on a new ``obj_class`` object.

        Three SIDs with RID 4294967294 are used, none of which may exist
        beforehand: one below the local domain SID, one below S-1-5-32 and
        one directly below S-1-5.  The first two adds must be rejected with
        the error pairs listed below; the third must succeed, after which
        exactly one object with that objectSid exists.  That object and the
        test object are then deleted and the SID must be gone again.

        :param obj_class: objectClass of the test object to create
        :param fpo_attr: attribute the ``<SID=...>`` values are added to
        """
        dom_sid = self.samdb.get_domain_sid()
        lsid_str = str(dom_sid) + "-4294967294"
        bsid_str = "S-1-5-32-4294967294"
        fsid_str = "S-1-5-4294967294"
        basedn = self.samdb.get_default_basedn()
        dn_str = "cn=%s,cn=Users,%s" % ("dsdb_test_fpo", basedn)
        dn = ldb.Dn(self.samdb, dn_str)

        def find_by_sid(sid_str):
            # Subtree search below basedn for objects carrying this SID.
            return self.samdb.search(scope=ldb.SCOPE_SUBTREE,
                                     base=basedn,
                                     expression="(objectSid=%s)" % sid_str,
                                     attrs=[])

        def sid_add_message(sid_str):
            # Message adding "<SID=...>" to fpo_attr on the test object.
            mod = ldb.Message()
            mod.dn = dn
            mod[fpo_attr] = ldb.MessageElement("<SID=%s>" % sid_str,
                                               ldb.FLAG_MOD_ADD,
                                               fpo_attr)
            return mod

        # None of the three SIDs may be present before the test starts.
        for sid_str in (lsid_str, bsid_str, fsid_str):
            self.assertEqual(len(find_by_sid(sid_str)), 0)

        self.addCleanup(delete_force, self.samdb, dn_str)

        self.samdb.add({
            "dn": dn_str,
            "objectClass": obj_class})

        # (SID, failure text, expected ldb code, expected WERROR)
        rejected = (
            (lsid_str,
             "No exception should get LDB_ERR_UNWILLING_TO_PERFORM",
             ldb.ERR_UNWILLING_TO_PERFORM,
             werror.WERR_DS_INVALID_GROUP_TYPE),
            (bsid_str,
             "No exception should get LDB_ERR_NO_SUCH_OBJECT",
             ldb.ERR_NO_SUCH_OBJECT,
             werror.WERR_NO_SUCH_MEMBER),
        )
        for sid_str, no_error_text, lerr, werr_code in rejected:
            mod = sid_add_message(sid_str)
            try:
                self.samdb.modify(mod)
                self.fail(no_error_text)
            except ldb.LdbError as e:
                (code, estr) = e.args
                self.assertEqual(code, lerr, str(e))
                werr = "%08X" % werr_code
                self.assertTrue(werr in estr, estr)

        # The SID directly below S-1-5 must be accepted.
        mod = sid_add_message(fsid_str)
        try:
            self.samdb.modify(mod)
        except ldb.LdbError as e:
            self.fail("Should have not raised an exception")

        # Exactly one object now carries that SID; remove it and the test
        # object, then verify the SID can no longer be found.
        matches = find_by_sid(fsid_str)
        self.assertEqual(len(matches), 1)
        self.samdb.delete(matches[0].dn)
        self.samdb.delete(dn)
        self.assertEqual(len(find_by_sid(fsid_str)), 0)

    def test_foreignSecurityPrincipal_member(self):
        """Run the SID reference scenario for 'member' on a group."""
        obj_class, fpo_attr = "group", "member"
        return self._test_foreignSecurityPrincipal(obj_class, fpo_attr)

    def test_foreignSecurityPrincipal_MembersForAzRole(self):
        """Run the SID reference scenario for 'msDS-MembersForAzRole'."""
        obj_class, fpo_attr = "msDS-AzRole", "msDS-MembersForAzRole"
        return self._test_foreignSecurityPrincipal(obj_class, fpo_attr)

    def test_foreignSecurityPrincipal_NeverRevealGroup(self):
        """Run the SID reference scenario for 'msDS-NeverRevealGroup'."""
        obj_class, fpo_attr = "computer", "msDS-NeverRevealGroup"
        return self._test_foreignSecurityPrincipal(obj_class, fpo_attr)

    def test_foreignSecurityPrincipal_RevealOnDemandGroup(self):
        """Run the SID reference scenario for 'msDS-RevealOnDemandGroup'."""
        obj_class, fpo_attr = "computer", "msDS-RevealOnDemandGroup"
        return self._test_foreignSecurityPrincipal(obj_class, fpo_attr)

    def _test_fail_foreignSecurityPrincipal(self, obj_class, fpo_attr,
                                            msg_exp, lerr_exp, werr_exp,
                                            allow_reference=True):
        """Check that SID references in ``fpo_attr`` are always rejected.

        Two objects of ``obj_class`` are created.  Adding ``<SID=...>``
        values to ``fpo_attr`` of the first one must fail with
        ``lerr_exp`` / ``werr_exp`` for each of three SIDs with RID
        4294967294 (below the domain SID, below S-1-5-32 and below S-1-5).
        Finally the DN of the second object is added: this must succeed
        when ``allow_reference`` is true and fail with the same error pair
        otherwise.

        :param obj_class: objectClass of the two test objects
        :param fpo_attr: attribute under test
        :param msg_exp: name of the expected error, used in failure text
        :param lerr_exp: expected ldb error code
        :param werr_exp: expected WERROR code (searched for as 8 hex
            digits in the error string)
        :param allow_reference: whether a plain DN reference is accepted
        """

        dom_sid = self.samdb.get_domain_sid()
        lsid_str = str(dom_sid) + "-4294967294"
        bsid_str = "S-1-5-32-4294967294"
        fsid_str = "S-1-5-4294967294"
        basedn = self.samdb.get_default_basedn()
        cn1 = "dsdb_test_fpo1"
        dn1_str = "cn=%s,cn=Users,%s" % (cn1, basedn)
        dn1 = ldb.Dn(self.samdb, dn1_str)
        cn2 = "dsdb_test_fpo2"
        dn2_str = "cn=%s,cn=Users,%s" % (cn2, basedn)
        dn2 = ldb.Dn(self.samdb, dn2_str)

        # None of the three SIDs may exist before the test starts.
        for sid_str in (lsid_str, bsid_str, fsid_str):
            res = self.samdb.search(scope=ldb.SCOPE_SUBTREE,
                                    base=basedn,
                                    expression="(objectSid=%s)" % sid_str,
                                    attrs=[])
            self.assertEqual(len(res), 0)

        self.addCleanup(delete_force, self.samdb, dn1_str)
        self.addCleanup(delete_force, self.samdb, dn2_str)

        self.samdb.add({
            "dn": dn1_str,
            "objectClass": obj_class})

        self.samdb.add({
            "dn": dn2_str,
            "objectClass": obj_class})

        # Every SID reference must be rejected with the expected errors.
        for sid_str in (lsid_str, bsid_str, fsid_str):
            msg = ldb.Message()
            msg.dn = dn1
            msg[fpo_attr] = ldb.MessageElement("<SID=%s>" % sid_str,
                                               ldb.FLAG_MOD_ADD,
                                               fpo_attr)
            try:
                self.samdb.modify(msg)
                # Report the expected error name; the third attempt used
                # to format the ldb.Message object here instead.
                self.fail("No exception should get %s" % msg_exp)
            except ldb.LdbError as e:
                (code, estr) = e.args
                self.assertEqual(code, lerr_exp, str(e))
                werr = "%08X" % werr_exp
                self.assertTrue(werr in estr, estr)

        # A plain DN reference to the second object.
        msg = ldb.Message()
        msg.dn = dn1
        msg[fpo_attr] = ldb.MessageElement("%s" % dn2,
                                           ldb.FLAG_MOD_ADD,
                                           fpo_attr)
        try:
            self.samdb.modify(msg)
            if not allow_reference:
                # Was "sel.fail", which raised NameError instead of
                # reporting the missing error.
                self.fail("No exception should get %s" % msg_exp)
        except ldb.LdbError as e:
            if allow_reference:
                self.fail("Should have not raised an exception: %s" % e)
            (code, estr) = e.args
            self.assertEqual(code, lerr_exp, str(e))
            werr = "%08X" % werr_exp
            self.assertTrue(werr in estr, estr)

        self.samdb.delete(dn2)
        self.samdb.delete(dn1)

    def test_foreignSecurityPrincipal_NonMembers(self):
        """'msDS-NonMembers' on a group rejects SID and DN references."""
        expected_name = "LDB_ERR_UNWILLING_TO_PERFORM/WERR_NOT_SUPPORTED"
        return self._test_fail_foreignSecurityPrincipal(
            "group", "msDS-NonMembers",
            expected_name,
            ldb.ERR_UNWILLING_TO_PERFORM, werror.WERR_NOT_SUPPORTED,
            allow_reference=False)

    def test_foreignSecurityPrincipal_HostServiceAccount(self):
        """'msDS-HostServiceAccount' on a computer rejects SID references."""
        expected_name = \
            "LDB_ERR_CONSTRAINT_VIOLATION/WERR_DS_NAME_REFERENCE_INVALID"
        return self._test_fail_foreignSecurityPrincipal(
            "computer", "msDS-HostServiceAccount",
            expected_name,
            ldb.ERR_CONSTRAINT_VIOLATION,
            werror.WERR_DS_NAME_REFERENCE_INVALID)

    def test_foreignSecurityPrincipal_manager(self):
        """'manager' on a user rejects SID references."""
        expected_name = \
            "LDB_ERR_CONSTRAINT_VIOLATION/WERR_DS_NAME_REFERENCE_INVALID"
        return self._test_fail_foreignSecurityPrincipal(
            "user", "manager",
            expected_name,
            ldb.ERR_CONSTRAINT_VIOLATION,
            werror.WERR_DS_NAME_REFERENCE_INVALID)

    #
    # Duplicate objectSIDs should not be permitted for SIDs in the local
    # domain. The test sequence is: add an object, delete it, then attempt
    # to re-add it; this should fail with a constraint violation.
    #
    def test_duplicate_objectSIDs_not_allowed_on_local_objects(self):
        """Re-adding a deleted user with the same local SID must fail.

        A user is added with an explicit objectSID built from the domain
        SID and a freshly allocated RID, deleted, and added again; the
        second add must raise ERR_CONSTRAINT_VIOLATION.
        """
        domain_sid = self.samdb.get_domain_sid()
        local_sid_str = "%s-%s" % (domain_sid, self.allocate_rid())
        packed_sid = ndr_pack(security.dom_sid(local_sid_str))
        base = self.samdb.get_default_basedn()
        user_dn = "cn=%s,cn=Users,%s" % ("dsdb_test_01", base)

        def new_record():
            # A fresh dict per add, as in the two separate literals before.
            return {"dn": user_dn,
                    "objectClass": "user",
                    "objectSID": packed_sid}

        self.samdb.add(new_record())
        self.samdb.delete(user_dn)

        try:
            self.samdb.add(new_record())
            self.fail("No exception should get LDB_ERR_CONSTRAINT_VIOLATION")
        except ldb.LdbError as e:
            (code, estr) = e.args
            if code != ldb.ERR_CONSTRAINT_VIOLATION:
                self.fail("Got %d - %s should have got "
                          "LDB_ERR_CONSTRAINT_VIOLATION"
                          % (code, estr))

    def test_linked_vs_non_linked_reference(self):
        """Compare 'manager' and 'assistant' references to a deleted user.

        Two users are created and one of them is deleted again.  On the
        remaining user:

        * adding the deleted user's SID or GUID to ``manager`` must fail
          with ERR_CONSTRAINT_VIOLATION / WERR_DS_NAME_REFERENCE_INVALID;
        * adding and then removing the same SID and GUID on ``assistant``
          must succeed;
        * adding a DN, SID or GUID that never existed to ``assistant``
          must fail with the same error pair.
        """
        basedn = self.samdb.get_default_basedn()
        kept_dn_str = "cn=reference_kept,cn=Users,%s" % (basedn)
        removed_dn_str = "cn=reference_removed,cn=Users,%s" % (basedn)
        dom_sid = self.samdb.get_domain_sid()
        none_sid_str = str(dom_sid) + "-4294967294"
        none_guid_str = "afafafaf-fafa-afaf-fafa-afafafafafaf"

        self.addCleanup(delete_force, self.samdb, kept_dn_str)
        self.addCleanup(delete_force, self.samdb, removed_dn_str)

        self.samdb.add({
            "dn": kept_dn_str,
            "objectClass": "user"})
        res = self.samdb.search(scope=ldb.SCOPE_SUBTREE,
                                base=kept_dn_str,
                                attrs=["objectGUID", "objectSID"])
        self.assertEqual(len(res), 1)
        # Only the DN of the kept user is needed below.
        kept_dn = res[0].dn

        self.samdb.add({
            "dn": removed_dn_str,
            "objectClass": "user"})
        res = self.samdb.search(scope=ldb.SCOPE_SUBTREE,
                                base=removed_dn_str,
                                attrs=["objectGUID", "objectSID"])
        self.assertEqual(len(res), 1)
        removed_guid = ndr_unpack(misc.GUID, res[0]["objectGUID"][0])
        removed_sid = ndr_unpack(security.dom_sid, res[0]["objectSid"][0])
        self.samdb.delete(removed_dn_str)

        def modify_reference(attr, value, flag):
            # Apply a single-value change of attr on the kept user.
            msg = ldb.Message()
            msg.dn = kept_dn
            msg[attr] = ldb.MessageElement(value, flag, attr)
            self.samdb.modify(msg)

        def assert_reference_invalid(attr, value):
            # Adding value to attr must be refused as an invalid reference.
            try:
                modify_reference(attr, value, ldb.FLAG_MOD_ADD)
                self.fail("No exception should get LDB_ERR_CONSTRAINT_VIOLATION")
            except ldb.LdbError as e:
                (code, estr) = e.args
                self.assertEqual(code, ldb.ERR_CONSTRAINT_VIOLATION, str(e))
                werr = "%08X" % werror.WERR_DS_NAME_REFERENCE_INVALID
                self.assertTrue(werr in estr, estr)

        #
        # First try the linked attribute 'manager'
        # by GUID and SID
        #
        assert_reference_invalid("manager", "<SID=%s>" % removed_sid)
        assert_reference_invalid("manager", "<GUID=%s>" % removed_guid)

        #
        # Try the non-linked attribute 'assistant'
        # by GUID and SID, which should work.
        #
        for reference in ("<SID=%s>" % removed_sid,
                          "<GUID=%s>" % removed_guid):
            modify_reference("assistant", reference, ldb.FLAG_MOD_ADD)
            modify_reference("assistant", reference, ldb.FLAG_MOD_DELETE)

        #
        # Finally try the non-linked attribute 'assistant'
        # but with non existing GUID, SID, DN
        #
        assert_reference_invalid("assistant", "CN=NoneNone,%s" % (basedn))
        assert_reference_invalid("assistant", "<SID=%s>" % none_sid_str)
        assert_reference_invalid("assistant", "<GUID=%s>" % none_guid_str)

        self.samdb.delete(kept_dn)

    def test_normalize_dn_in_domain_full(self):
        """A full DN given as a string is returned unchanged."""
        base = self.samdb.domain_dn()

        # Build CN=Users,<domain dn> as the expected result.
        expected = ldb.Dn(self.samdb, "CN=Users")
        expected.add_base(base)

        as_text = str(expected)
        normalized = self.samdb.normalize_dn_in_domain(as_text)

        # That is, no change
        self.assertEqual(expected, normalized)

    def test_normalize_dn_in_domain_part(self):
        """A partial DN string gets the domain DN appended."""
        base = self.samdb.domain_dn()
        partial = "CN=Users"

        expected = ldb.Dn(self.samdb, partial)
        expected.add_base(base)

        normalized = self.samdb.normalize_dn_in_domain(partial)

        # That is, the domain DN appended
        self.assertEqual(expected, normalized)

    def test_normalize_dn_in_domain_full_dn(self):
        """A full DN given as an ldb.Dn object is returned unchanged."""
        base = self.samdb.domain_dn()

        # Build CN=Users,<domain dn>; this same object is passed in.
        expected = ldb.Dn(self.samdb, "CN=Users")
        expected.add_base(base)

        normalized = self.samdb.normalize_dn_in_domain(expected)

        # That is, no change
        self.assertEqual(expected, normalized)

    def test_normalize_dn_in_domain_part_dn(self):
        """A partial ldb.Dn object gets the domain DN appended."""
        base = self.samdb.domain_dn()
        partial = ldb.Dn(self.samdb, "CN=Users")

        # Expected value is built from the two string forms.
        expected = ldb.Dn(self.samdb, "%s,%s" % (partial, base))
        normalized = self.samdb.normalize_dn_in_domain(partial)

        # That is, the domain DN appended
        self.assertEqual(expected, normalized)
Exemple #24
0
class VLVTests(samba.tests.TestCase):
    """Fixture that fills an ``ou=vlv`` container with generated users."""

    def create_user(self, i, n, prefix='vlvtest', suffix='', attrs=None):
        """Build user number ``i`` of ``n``, add it to the database and
        return its attribute dict.

        :param i: index of the user; most attribute values derive from it
        :param n: total number of users, used by some attribute values
        :param prefix: text placed before the index in the cn
        :param suffix: text placed after the index in the cn
        :param attrs: optional dict of attributes overriding the defaults
        :return: the dict that was passed to ``self.ldb.add``; it is also
            appended to ``self.users``
        """
        name = "%s%d%s" % (prefix, i, suffix)
        user = {
            'cn': name,
            "objectclass": "user",
            'givenName': "abcdefghijklmnopqrstuvwxyz"[i % 26],
            "roomNumber": "%sbc" % (n - i),
            "carLicense": "后来经",
            "employeeNumber": "%s%sx" % (abs(i * (99 - i)), '\n' * (i & 255)),
            "accountExpires": "%s" % (10 ** 9 + 1000000 * i),
            "msTSExpireDate4": "19%02d0101010000.0Z" % (i % 100),
            "flags": str(i * (n - i)),
            "serialNumber": "abc %s%s%s" % ('AaBb |-/'[i & 7],
                                            ' 3z}'[i & 3],
                                            '"@'[i & 1],),
        }

        # _user_broken_attrs tests are broken due to problems outside
        # of VLV.  The dict is kept for reference only and is not merged
        # into the user.
        _user_broken_attrs = {
            # Sort doesn't look past a NUL byte.
            "photo": "\x00%d" % (n - i),
            "audio": "%sn octet string %s%s ♫♬\x00lalala" % ('Aa'[i & 1],
                                                             chr(i & 255), i),
            "displayNamePrintable": "%d\x00%c" % (i, i & 255),
            "adminDisplayName": "%d\x00b" % (n-i),
            "title": "%d%sb" % (n - i, '\x00' * i),
            "comment": "Favourite colour is %d" % (n % (i + 1)),

            # Names that vary only in case. Windows returns
            # equivalent addresses in the order they were put
            # in ('a st', 'A st',...).
            "street": "%s st" % (chr(65 | (i & 14) | ((i & 1) * 32))),
        }

        if attrs is not None:
            user.update(attrs)

        user['dn'] = "cn=%s,%s" % (user['cn'], self.ou)

        if opts.skip_attr_regex:
            match = re.compile(opts.skip_attr_regex).search
            # Iterate over a snapshot of the keys: deleting from a dict
            # while iterating its live key view raises RuntimeError on
            # Python 3.
            for k in list(user):
                if match(k):
                    del user[k]

        self.users.append(user)
        self.ldb.add(user)
        return user

    def setUp(self):
        """Connect to the server and populate the ``ou=vlv`` container.

        Creates N_ELEMENTS users through ``create_user`` and then builds
        the per-category attribute name lists (``binary_sorted_keys``,
        ``numeric_sorted_keys``, ``timestamp_keys``, ``int64_keys``,
        ``locale_sorted_keys`` and ``delicate_keys``).
        """
        super(VLVTests, self).setUp()
        self.ldb = SamDB(host, credentials=creds,
                         session_info=system_session(lp), lp=lp)

        self.base_dn = self.ldb.domain_dn()
        self.ou = "ou=vlv,%s" % self.base_dn
        if opts.delete_in_setup:
            # Best effort: a failing tree delete is reported, not fatal.
            try:
                self.ldb.delete(self.ou, ['tree_delete:1'])
            except ldb.LdbError as e:
                # Single-argument print() behaves the same on Python 2 and 3.
                print("tried deleting %s, got error %s" % (self.ou, e))
        self.ldb.add({
            "dn": self.ou,
            "objectclass": "organizationalUnit"})

        self.users = []
        for i in range(N_ELEMENTS):
            self.create_user(i, N_ELEMENTS)

        attrs = self.users[0].keys()
        self.binary_sorted_keys = ['audio',
                                   'photo',
                                   "msTSExpireDate4",
                                   'serialNumber',
                                   "displayNamePrintable"]

        self.numeric_sorted_keys = ['flags',
                                    'accountExpires']

        self.timestamp_keys = ['msTSExpireDate4']

        self.int64_keys = set(['accountExpires'])

        # Everything that is neither binary nor numeric is treated as
        # locale sorted; the exclusion list is built once, not per key.
        not_locale_sorted = (self.binary_sorted_keys +
                             self.numeric_sorted_keys)
        self.locale_sorted_keys = [x for x in attrs
                                   if x not in not_locale_sorted]

        # don't try spaces, etc in cn
        self.delicate_keys = ['cn']
class UserTests(samba.tests.TestCase):
    # Timing-oriented tests that add, link, unlink and delete users and
    # groups in batches.  Progress counters live on the shared GlobalState
    # class so they carry over from one test method to the next; the
    # numeric prefixes in the test names determine the alphabetical order
    # of the methods.

    def add_if_possible(self, *args, **kwargs):
        """In these tests sometimes things are left in the database
        deliberately, so we don't worry if we fail to add them a second
        time."""
        try:
            self.ldb.add(*args, **kwargs)
        except LdbError:
            pass

    def setUp(self):
        """Connect to the server and make sure the per-process OUs exist."""
        super(UserTests, self).setUp()
        self.state = GlobalState  # the class itself, not an instance
        self.lp = lp
        self.ldb = SamDB(host, credentials=creds,
                         session_info=system_session(lp), lp=lp)
        self.base_dn = self.ldb.domain_dn()
        # The top-level OU is named after the process id.
        self.ou = "OU=pid%s,%s" % (os.getpid(), self.base_dn)
        self.ou_users = "OU=users,%s" % self.ou
        self.ou_groups = "OU=groups,%s" % self.ou
        self.ou_computers = "OU=computers,%s" % self.ou

        for dn in (self.ou, self.ou_users, self.ou_groups,
                   self.ou_computers):
            self.add_if_possible({
                "dn": dn,
                "objectclass": "organizationalUnit"})

    def tearDown(self):
        super(UserTests, self).tearDown()

    def test_00_00_do_nothing(self):
        # this gives us an idea of the overhead
        pass

    def _prepare_n_groups(self, n):
        # Record the group count and add groups g0..g(n-1), ignoring add
        # failures (see add_if_possible).
        self.state.n_groups = n
        for i in range(n):
            self.add_if_possible({
                "dn": "cn=g%d,%s" % (i, self.ou_groups),
                "objectclass": "group"})

    def _add_users(self, start, end):
        # Add users u<start>..u<end - 1>; unlike groups, failures propagate.
        for i in range(start, end):
            self.ldb.add({
                "dn": "cn=u%d,%s" % (i, self.ou_users),
                "objectclass": "user"})

    def _test_join(self):
        # Run "samba-tool domain join ... dc" into a temporary target
        # directory, which is removed again even if the join raises.
        tmpdir = tempfile.mkdtemp()
        try:
            if '://' in host:
                server = host.split('://', 1)[1]
            else:
                server = host
            cmd = cmd_sambatool.subcommands['domain'].subcommands['join']
            cmd._run("samba-tool domain join",
                     creds.get_realm(),
                     "dc", "-U%s%%%s" % (creds.get_username(),
                                         creds.get_password()),
                     '--targetdir=%s' % tmpdir,
                     '--server=%s' % server)
        finally:
            shutil.rmtree(tmpdir)

    def _test_unindexed_search(self):
        # Run each expression 10 times and report the elapsed time on
        # stderr.  The first number printed is the last loop index.
        expressions = [
            ('(&(objectclass=user)(description='
             'Built-in account for adminstering the computer/domain))'),
            '(description=Built-in account for adminstering the computer/domain)',
            '(objectCategory=*)',
            '(samaccountname=Administrator*)'
        ]
        for expression in expressions:
            t = time.time()
            for i in range(10):
                self.ldb.search(self.ou,
                                expression=expression,
                                scope=SCOPE_SUBTREE,
                                attrs=['cn'])
            print('%d %s took %s' % (i, expression,
                                     time.time() - t), file=sys.stderr)

    def _test_indexed_search(self):
        # Run each expression 100 times and report the elapsed time on
        # stderr.  The first number printed is the last loop index.
        expressions = ['(objectclass=group)',
                       '(samaccountname=Administrator)'
                       ]
        for expression in expressions:
            t = time.time()
            for i in range(100):
                self.ldb.search(self.ou,
                                expression=expression,
                                scope=SCOPE_SUBTREE,
                                attrs=['cn'])
            print('%d runs %s took %s' % (i, expression,
                                          time.time() - t), file=sys.stderr)

    def _test_add_many_users(self, n=BATCH_SIZE):
        # Add the next n users and advance the shared user counter.
        s = self.state.next_user_id
        e = s + n
        self._add_users(s, e)
        self.state.next_user_id = e

    test_00_00_join_empty_dc = _test_join

    test_00_01_adding_users_1000 = _test_add_many_users
    test_00_02_adding_users_2000 = _test_add_many_users
    test_00_03_adding_users_3000 = _test_add_many_users

    test_00_10_join_unlinked_dc = _test_join
    test_00_11_unindexed_search_3k_users = _test_unindexed_search
    test_00_12_indexed_search_3k_users = _test_indexed_search

    def _link_user_and_group(self, u, g):
        # Add user u<u> to the "member" attribute of group g<g>.
        m = Message()
        m.dn = Dn(self.ldb, "CN=g%d,%s" % (g, self.ou_groups))
        m["member"] = MessageElement("cn=u%d,%s" % (u, self.ou_users),
                                     FLAG_MOD_ADD, "member")
        self.ldb.modify(m)

    def _unlink_user_and_group(self, u, g):
        # Remove user u<u> from the "member" attribute of group g<g>.
        # The user DN uses the same format as _add_users() and
        # _link_user_and_group().
        user = "cn=u%d,%s" % (u, self.ou_users)
        group = "CN=g%d,%s" % (g, self.ou_groups)
        m = Message()
        m.dn = Dn(self.ldb, group)
        m["member"] = MessageElement(user, FLAG_MOD_DELETE, "member")
        self.ldb.modify(m)

    def _test_link_many_users(self, n=BATCH_SIZE):
        # Link the next n users, user i going into group i % N_GROUPS.
        self._prepare_n_groups(N_GROUPS)
        s = self.state.next_linked_user
        e = s + n
        for i in range(s, e):
            g = i % N_GROUPS
            self._link_user_and_group(i, g)
        self.state.next_linked_user = e

    test_01_01_link_users_1000 = _test_link_many_users
    test_01_02_link_users_2000 = _test_link_many_users
    test_01_03_link_users_3000 = _test_link_many_users

    def _test_link_many_users_offset_1(self, n=BATCH_SIZE):
        # Link the next n users again, this time into group
        # (i + 1) % N_GROUPS.
        s = self.state.next_relinked_user
        e = s + n
        for i in range(s, e):
            g = (i + 1) % N_GROUPS
            self._link_user_and_group(i, g)
        self.state.next_relinked_user = e

    test_02_01_link_users_again_1000 = _test_link_many_users_offset_1
    test_02_02_link_users_again_2000 = _test_link_many_users_offset_1
    test_02_03_link_users_again_3000 = _test_link_many_users_offset_1

    test_02_10_join_partially_linked_dc = _test_join
    test_02_11_unindexed_search_partially_linked_dc = _test_unindexed_search
    test_02_12_indexed_search_partially_linked_dc = _test_indexed_search

    def _test_link_many_users_3_groups(self, n=BATCH_SIZE, groups=3):
        # Link the next n users into one of the first `groups` groups,
        # skipping any group that equals i % N_GROUPS or
        # (i + 1) % N_GROUPS (the groups the two linking helpers above
        # use for user i).
        s = self.state.next_linked_user_3
        e = s + n
        self.state.next_linked_user_3 = e
        for i in range(s, e):
            g = (i + 2) % groups
            if g not in (i % N_GROUPS, (i + 1) % N_GROUPS):
                self._link_user_and_group(i, g)

    test_03_01_link_users_again_1000_few_groups = _test_link_many_users_3_groups
    test_03_02_link_users_again_2000_few_groups = _test_link_many_users_3_groups
    test_03_03_link_users_again_3000_few_groups = _test_link_many_users_3_groups

    def _test_remove_links_0(self, n=BATCH_SIZE):
        # Remove the i % N_GROUPS link for the next n users.
        s = self.state.next_removed_link_0
        e = s + n
        self.state.next_removed_link_0 = e
        for i in range(s, e):
            g = i % N_GROUPS
            self._unlink_user_and_group(i, g)

    test_04_01_remove_some_links_1000 = _test_remove_links_0
    test_04_02_remove_some_links_2000 = _test_remove_links_0
    test_04_03_remove_some_links_3000 = _test_remove_links_0

    # back to using _test_add_many_users
    test_05_01_adding_users_after_links_4000 = _test_add_many_users

    # reset the link count, to replace the original links
    def test_06_01_relink_users_1000(self):
        self.state.next_linked_user = 0
        self._test_link_many_users()

    test_06_02_link_users_2000 = _test_link_many_users
    test_06_03_link_users_3000 = _test_link_many_users
    test_06_04_link_users_4000 = _test_link_many_users
    test_06_05_link_users_again_4000 = _test_link_many_users_offset_1
    test_06_06_link_users_again_4000_few_groups = _test_link_many_users_3_groups

    test_07_01_adding_users_after_links_5000 = _test_add_many_users

    def _test_link_random_users_and_groups(self, n=BATCH_SIZE, groups=100):
        # Attempt n random user/group links; LdbError from an individual
        # modify is deliberately ignored.
        self._prepare_n_groups(groups)
        for i in range(n):
            u = random.randrange(self.state.next_user_id)
            g = random.randrange(groups)
            try:
                self._link_user_and_group(u, g)
            except LdbError:
                pass

    test_08_01_link_random_users_100_groups = _test_link_random_users_and_groups
    test_08_02_link_random_users_100_groups = _test_link_random_users_and_groups

    test_10_01_unindexed_search_full_dc = _test_unindexed_search
    test_10_02_indexed_search_full_dc = _test_indexed_search
    test_11_02_join_full_dc = _test_join

    def test_20_01_delete_50_groups(self):
        # Delete the 50 highest-numbered groups.
        for i in range(self.state.n_groups - 50, self.state.n_groups):
            self.ldb.delete("cn=g%d,%s" % (i, self.ou_groups))
        self.state.n_groups -= 50

    def _test_delete_many_users(self, n=BATCH_SIZE):
        # Delete the n most recently added users (fewer if fewer remain).
        e = self.state.next_user_id
        s = max(0, e - n)
        self.state.next_user_id = s
        for i in range(s, e):
            self.ldb.delete("cn=u%d,%s" % (i, self.ou_users))

    test_21_01_delete_users_5000_lightly_linked = _test_delete_many_users
    test_21_02_delete_users_4000_lightly_linked = _test_delete_many_users
    test_21_03_delete_users_3000 = _test_delete_many_users

    def test_22_01_delete_all_groups(self):
        for i in range(self.state.n_groups):
            self.ldb.delete("cn=g%d,%s" % (i, self.ou_groups))
        self.state.n_groups = 0

    test_23_01_delete_users_after_groups_2000 = _test_delete_many_users
    test_23_00_delete_users_after_groups_1000 = _test_delete_many_users

    test_24_02_join_after_cleanup = _test_join
Exemple #26
0
class DsdbTests(TestCase):
    def setUp(self):
        """Open a SamDB connection, create a temporary user, and record
        the server reference DN and RID Set DN on the instance."""
        super(DsdbTests, self).setUp()
        self.lp = samba.tests.env_loadparm()
        self.creds = Credentials()
        self.creds.guess(self.lp)
        self.session = system_session()
        self.samdb = SamDB(session_info=self.session,
                           credentials=self.creds,
                           lp=self.lp)

        # Temporary account: random six-hex-digit suffix, random password.
        username = "dsdb-user-" + str(uuid.uuid4().hex[0:6])
        password = samba.generate_random_password(32, 32)
        domain_dn = self.samdb.domain_dn()

        self.account_dn = "".join(("CN=", username, ",CN=Users,", domain_dn))
        self.samdb.newuser(username=username,
                           password=password,
                           description="Test user for dsdb test")
        # Remove the account again once the test has finished.
        self.addCleanup(delete_force, self.samdb, self.account_dn)

        # Read serverReference from this server's own object.
        server_dn = ldb.Dn(self.samdb, self.samdb.get_serverName())
        server_res = self.samdb.search(base=server_dn,
                                       scope=ldb.SCOPE_BASE,
                                       attrs=["serverReference"])
        server_ref = server_res[0]["serverReference"][0].decode("utf-8")
        self.server_ref_dn = ldb.Dn(self.samdb, server_ref)

        # Follow rIDSetReferences on that object to the RID Set DN.
        ref_res = self.samdb.search(base=self.server_ref_dn,
                                    scope=ldb.SCOPE_BASE,
                                    attrs=["rIDSetReferences"])
        ref_entry = ref_res[0]
        self.assertIn("rIDSetReferences", ref_entry)
        rid_set = ref_entry["rIDSetReferences"][0].decode("utf-8")
        self.rid_set_dn = ldb.Dn(self.samdb, rid_set)

    def get_rid_set(self, rid_set_dn):
        """Return the first result of a base-scope search on rid_set_dn
        that requests the four RID pool attributes."""
        wanted = [
            "rIDAllocationPool",
            "rIDPreviousAllocationPool", "rIDUsedPool",
            "rIDNextRID"
        ]
        matches = self.samdb.search(base=rid_set_dn,
                                    scope=ldb.SCOPE_BASE,
                                    attrs=wanted)
        return matches[0]

    def test_ridalloc_next_free_rid(self):
        # Test RID allocation. We assume that RID
        # pools allocated to us are contiguous.
        # All work happens in a transaction that is always cancelled.
        rid_attrs = ("rIDAllocationPool",
                     "rIDPreviousAllocationPool",
                     "rIDUsedPool",
                     "rIDNextRID")
        self.samdb.transaction_start()
        try:
            before = self.get_rid_set(self.rid_set_dn)
            for attr in rid_attrs:
                self.assertIn(attr, before)

            # rIDNextRID as currently stored in the RID set.
            stored_rid = int(before["rIDNextRID"][0])

            # next_free_rid() must report the value after it.
            first_answer = self.samdb.next_free_rid()
            self.assertEqual(stored_rid + 1, first_answer)

            # Asking again must give the same answer.
            second_answer = self.samdb.next_free_rid()
            self.assertEqual(first_answer, second_answer)

            # The RID set attributes must be unchanged.
            after = self.get_rid_set(self.rid_set_dn)
            self.assertEqual(before, after)
        finally:
            self.samdb.transaction_cancel()

    def test_ridalloc_no_ridnextrid(self):
        # With rIDNextRID deleted, RIDs are expected to come from the
        # start of rIDAllocationPool.  Pools are stored as a single
        # integer: (high << 32) | low.
        prev_lo, prev_hi = 1000, 1999
        next_lo, next_hi = 3000, 3999
        prev_packed = str((prev_hi << 32) | prev_lo)
        next_packed = str((next_hi << 32) | next_lo)

        self.samdb.transaction_start()
        try:
            # Delete rIDNextRID and install the previous and next pools.
            change = ldb.Message()
            change.dn = self.rid_set_dn
            change["rIDNextRID"] = ldb.MessageElement(
                [], ldb.FLAG_MOD_DELETE, "rIDNextRID")
            change["rIDPreviousAllocationPool"] = ldb.MessageElement(
                prev_packed, ldb.FLAG_MOD_REPLACE,
                "rIDPreviousAllocationPool")
            change["rIDAllocationPool"] = ldb.MessageElement(
                next_packed, ldb.FLAG_MOD_REPLACE, "rIDAllocationPool")
            self.samdb.modify(change)

            # next_free_rid() must return the start of the next pool.
            predicted = self.samdb.next_free_rid()
            self.assertEqual(next_lo, predicted)

            # allocate_rid() must hand out that same RID.
            allocated = self.samdb.allocate_rid()
            self.assertEqual(predicted, allocated)

            # After allocating, the next free RID has moved on by one.
            predicted_after = self.samdb.next_free_rid()
            self.assertEqual(allocated + 1, predicted_after)

            # The free range runs from there to the end of the next pool.
            free_lo, free_hi = self.samdb.free_rid_bounds()
            self.assertEqual(allocated + 1, free_lo)
            self.assertEqual(next_hi, free_hi)
        finally:
            self.samdb.transaction_cancel()

    def test_ridalloc_no_free_rids(self):
        # Both pools are set to the same range and rIDNextRID to its
        # upper bound, so no free RID is left to report.
        pool_lo, pool_hi = 2000, 2999
        packed_pool = str((pool_hi << 32) | pool_lo)

        self.samdb.transaction_start()
        try:
            # Exhaust our current pool of RIDs.
            exhaust = ldb.Message()
            exhaust.dn = self.rid_set_dn
            exhaust["rIDPreviousAllocationPool"] = ldb.MessageElement(
                packed_pool, ldb.FLAG_MOD_REPLACE,
                "rIDPreviousAllocationPool")
            exhaust["rIDAllocationPool"] = ldb.MessageElement(
                packed_pool, ldb.FLAG_MOD_REPLACE, "rIDAllocationPool")
            exhaust["rIDNextRID"] = ldb.MessageElement(
                str(pool_hi), ldb.FLAG_MOD_REPLACE, "rIDNextRID")
            self.samdb.modify(exhaust)

            # Calculating the next free RID must fail ...
            with self.assertRaises(ldb.LdbError) as failure:
                self.samdb.next_free_rid()

            # ... with this exact error string.
            self.assertEqual("RID pools out of RIDs",
                             failure.exception.args[1])

            # Allocating a new RID must still succeed.
            self.samdb.allocate_rid()
        finally:
            self.samdb.transaction_cancel()

    def test_ridalloc_new_ridset(self):
        # Test what happens with RID Set values set to zero (similar to
        # when a RID Set is first created, except we also set
        # rIDAllocationPool to zero).
        zeroed_attrs = ("rIDPreviousAllocationPool",
                        "rIDAllocationPool",
                        "rIDNextRID")
        pool_lo, pool_hi = 2000, 2999

        self.samdb.transaction_start()
        try:
            zero_msg = ldb.Message()
            zero_msg.dn = self.rid_set_dn
            for attr in zeroed_attrs:
                zero_msg[attr] = ldb.MessageElement(
                    "0", ldb.FLAG_MOD_REPLACE, attr)
            self.samdb.modify(zero_msg)

            # With everything zeroed, there is no free RID to calculate.
            with self.assertRaises(ldb.LdbError) as failure:
                self.samdb.next_free_rid()

            self.assertEqual("RID pools out of RIDs",
                             failure.exception.args[1])

            # Now provide values for the next pool only.
            pool_msg = ldb.Message()
            pool_msg.dn = self.rid_set_dn
            pool_msg["rIDAllocationPool"] = ldb.MessageElement(
                str((pool_hi << 32) | pool_lo), ldb.FLAG_MOD_REPLACE,
                "rIDAllocationPool")
            self.samdb.modify(pool_msg)

            # The next free RID is the lower bound of that pool.
            predicted = self.samdb.next_free_rid()
            self.assertEqual(pool_lo, predicted)

            # The free range is the whole pool.
            free_lo, free_hi = self.samdb.free_rid_bounds()
            self.assertEqual(pool_lo, free_lo)
            self.assertEqual(pool_hi, free_hi)
        finally:
            self.samdb.transaction_cancel()

    def test_ridalloc_move_to_new_pool(self):
        # Test moving to a new pool from the previous pool: leave exactly
        # one RID in the previous pool, use it up, then expect RIDs to
        # come from the new pool.
        old_lo, old_hi = 2000, 2999
        new_lo, new_hi = 4500, 4599

        self.samdb.transaction_start()
        try:
            setup = ldb.Message()
            setup.dn = self.rid_set_dn
            setup["rIDPreviousAllocationPool"] = ldb.MessageElement(
                str((old_hi << 32) | old_lo), ldb.FLAG_MOD_REPLACE,
                "rIDPreviousAllocationPool")
            setup["rIDAllocationPool"] = ldb.MessageElement(
                str((new_hi << 32) | new_lo), ldb.FLAG_MOD_REPLACE,
                "rIDAllocationPool")
            # rIDNextRID is one below the top of the previous pool.
            setup["rIDNextRID"] = ldb.MessageElement(
                str(old_hi - 1), ldb.FLAG_MOD_REPLACE, "rIDNextRID")
            self.samdb.modify(setup)

            # We should have remained in the previous pool.
            last_old = self.samdb.next_free_rid()
            self.assertEqual(old_hi, last_old)

            # Only that single RID is available.
            low, high = self.samdb.free_rid_bounds()
            self.assertEqual(old_hi, low)
            self.assertEqual(old_hi, high)

            # Allocate it.
            taken = self.samdb.allocate_rid()
            self.assertEqual(last_old, taken)

            # We should now move to the next pool.
            first_new = self.samdb.next_free_rid()
            self.assertEqual(new_lo, first_new)

            # The free range is now the whole new pool.
            low2, high2 = self.samdb.free_rid_bounds()
            self.assertEqual(new_lo, low2)
            self.assertEqual(new_hi, high2)

            # allocate_rid() must agree with next_free_rid().
            taken_new = self.samdb.allocate_rid()
            self.assertEqual(first_new, taken_new)
        finally:
            self.samdb.transaction_cancel()

    def test_ridalloc_no_ridsetreferences(self):
        # With rIDSetReferences deleted from the server reference object,
        # both next_free_rid() and allocate_rid() must fail, each with
        # its own error code and message.
        self.samdb.transaction_start()
        try:
            # Delete the rIDSetReferences attribute.
            drop_ref = ldb.Message()
            drop_ref.dn = self.server_ref_dn
            drop_ref["rIDSetReferences"] = ldb.MessageElement(
                [], ldb.FLAG_MOD_DELETE, "rIDSetReferences")
            self.samdb.modify(drop_ref)

            # Calculating the next free RID must fail.
            with self.assertRaises(ldb.LdbError) as lookup_failure:
                self.samdb.next_free_rid()

            code, text = lookup_failure.exception.args
            self.assertEqual(ldb.ERR_NO_SUCH_ATTRIBUTE, code)
            expected_lookup = (
                "No RID Set DN - "
                "Cannot find attribute rIDSetReferences of %s "
                "to calculate reference dn" % self.server_ref_dn)
            self.assertIn(expected_lookup, text)

            # Allocating a new RID must fail too.
            with self.assertRaises(ldb.LdbError) as alloc_failure:
                self.samdb.allocate_rid()

            code, text = alloc_failure.exception.args
            self.assertEqual(ldb.ERR_ENTRY_ALREADY_EXISTS, code)
            expected_alloc = (
                "No RID Set DN - "
                "Failed to add RID Set %s - "
                "Entry %s already exists" % (self.rid_set_dn, self.rid_set_dn))
            self.assertIn(expected_alloc, text)
        finally:
            self.samdb.transaction_cancel()

    def test_ridalloc_no_rid_set(self):
        # Point rIDSetReferences at the test user account instead of a
        # RID Set; both RID calls must then fail with "Bad RID Set ...".
        bogus_rid_set = self.account_dn
        expected_text = "Bad RID Set " + bogus_rid_set

        self.samdb.transaction_start()
        try:
            redirect = ldb.Message()
            redirect.dn = self.server_ref_dn
            redirect["rIDSetReferences"] = ldb.MessageElement(
                bogus_rid_set, ldb.FLAG_MOD_REPLACE, "rIDSetReferences")
            self.samdb.modify(redirect)

            # Calculating the next free RID must fail.
            with self.assertRaises(ldb.LdbError) as lookup_failure:
                self.samdb.next_free_rid()

            code, text = lookup_failure.exception.args
            self.assertEqual(ldb.ERR_OPERATIONS_ERROR, code)
            self.assertIn(expected_text, text)

            # Allocating a new RID must fail the same way.
            with self.assertRaises(ldb.LdbError) as alloc_failure:
                self.samdb.allocate_rid()

            code, text = alloc_failure.exception.args
            self.assertEqual(ldb.ERR_OPERATIONS_ERROR, code)
            self.assertIn(expected_text, text)
        finally:
            self.samdb.transaction_cancel()

    def test_get_oid_from_attrid(self):
        # attid 591614 must map to this OID string.
        self.assertEqual(self.samdb.get_oid_from_attid(591614),
                         "1.2.840.113556.1.4.1790")

    def test_error_replpropertymetadata(self):
        # Bump only the version of the description entry in the account's
        # replPropertyMetaData and check that writing the blob back with
        # the local_oid control below is rejected with an LdbError.
        res = self.samdb.search(scope=ldb.SCOPE_SUBTREE,
                                base=self.account_dn,
                                attrs=["replPropertyMetaData"])
        repl = ndr_unpack(drsblobs.replPropertyMetaDataBlob,
                          res[0]["replPropertyMetaData"][0])
        for o in repl.ctr.array:
            # Search for Description
            if o.attid == 13:
                o.version = o.version + 1
        replBlob = ndr_pack(repl)
        msg = ldb.Message()
        msg.dn = res[0].dn
        msg["replPropertyMetaData"] = ldb.MessageElement(
            replBlob, ldb.FLAG_MOD_REPLACE, "replPropertyMetaData")
        self.assertRaises(ldb.LdbError, self.samdb.modify, msg,
                          ["local_oid:1.3.6.1.4.1.7165.4.3.14:0"])

    def test_error_replpropertymetadata_nochange(self):
        # Writing back an unmodified replPropertyMetaData blob with the
        # local_oid control below must raise LdbError.
        found = self.samdb.search(scope=ldb.SCOPE_SUBTREE,
                                  base=self.account_dn,
                                  attrs=["replPropertyMetaData"])
        entry = found[0]
        metadata = ndr_unpack(drsblobs.replPropertyMetaDataBlob,
                              entry["replPropertyMetaData"][0])
        # Round-trip through unpack/pack without touching anything.
        repacked = ndr_pack(metadata)
        change = ldb.Message()
        change.dn = entry.dn
        change["replPropertyMetaData"] = ldb.MessageElement(
            repacked, ldb.FLAG_MOD_REPLACE, "replPropertyMetaData")
        controls = ["local_oid:1.3.6.1.4.1.7165.4.3.14:0"]
        self.assertRaises(ldb.LdbError, self.samdb.modify, change, controls)

    def test_error_replpropertymetadata_allow_sort(self):
        # Same unmodified write-back as the "nochange" case, but with a
        # second local_oid control added; here modify() must not raise.
        found = self.samdb.search(scope=ldb.SCOPE_SUBTREE,
                                  base=self.account_dn,
                                  attrs=["replPropertyMetaData"])
        entry = found[0]
        metadata = ndr_unpack(drsblobs.replPropertyMetaDataBlob,
                              entry["replPropertyMetaData"][0])
        repacked = ndr_pack(metadata)
        change = ldb.Message()
        change.dn = entry.dn
        change["replPropertyMetaData"] = ldb.MessageElement(
            repacked, ldb.FLAG_MOD_REPLACE, "replPropertyMetaData")
        controls = [
            "local_oid:1.3.6.1.4.1.7165.4.3.14:0",
            "local_oid:1.3.6.1.4.1.7165.4.3.25:0"
        ]
        self.samdb.modify(change, controls)

    def test_twoatt_replpropertymetadata(self):
        # Replace replPropertyMetaData and description in one modify and
        # check that the request, sent with the local_oid control below,
        # is rejected with an LdbError.
        res = self.samdb.search(scope=ldb.SCOPE_SUBTREE,
                                base=self.account_dn,
                                attrs=["replPropertyMetaData", "uSNChanged"])
        repl = ndr_unpack(drsblobs.replPropertyMetaDataBlob,
                          res[0]["replPropertyMetaData"][0])
        for o in repl.ctr.array:
            # Search for Description
            if o.attid == 13:
                o.version = o.version + 1
                o.local_usn = int(str(res[0]["uSNChanged"])) + 1
        replBlob = ndr_pack(repl)
        msg = ldb.Message()
        msg.dn = res[0].dn
        msg["replPropertyMetaData"] = ldb.MessageElement(
            replBlob, ldb.FLAG_MOD_REPLACE, "replPropertyMetaData")
        msg["description"] = ldb.MessageElement("new val",
                                                ldb.FLAG_MOD_REPLACE,
                                                "description")
        self.assertRaises(ldb.LdbError, self.samdb.modify, msg,
                          ["local_oid:1.3.6.1.4.1.7165.4.3.14:0"])

    def test_set_replpropertymetadata(self):
        # Bump the version and both USN fields of the description entry
        # in replPropertyMetaData; with the local_oid control below the
        # modify must go through without raising.
        res = self.samdb.search(scope=ldb.SCOPE_SUBTREE,
                                base=self.account_dn,
                                attrs=["replPropertyMetaData", "uSNChanged"])
        repl = ndr_unpack(drsblobs.replPropertyMetaDataBlob,
                          res[0]["replPropertyMetaData"][0])
        for o in repl.ctr.array:
            # Search for Description
            if o.attid == 13:
                # Both USN fields get the same value: uSNChanged + 1.
                next_usn = int(str(res[0]["uSNChanged"])) + 1
                o.version = o.version + 1
                o.local_usn = next_usn
                o.originating_usn = next_usn
        replBlob = ndr_pack(repl)
        msg = ldb.Message()
        msg.dn = res[0].dn
        msg["replPropertyMetaData"] = ldb.MessageElement(
            replBlob, ldb.FLAG_MOD_REPLACE, "replPropertyMetaData")
        self.samdb.modify(msg, ["local_oid:1.3.6.1.4.1.7165.4.3.14:0"])

    def test_ok_get_attribute_from_attid(self):
        # attid 13 must resolve to the "description" attribute.
        attr_name = self.samdb.get_attribute_from_attid(13)
        self.assertEqual(attr_name, "description")

    def test_ko_get_attribute_from_attid(self):
        # For attid 11979 the lookup must return None.
        attr_name = self.samdb.get_attribute_from_attid(11979)
        self.assertEqual(attr_name, None)

    def test_get_attribute_replmetadata_version(self):
        # The test account's unicodePwd metadata version must be 2.
        found = self.samdb.search(scope=ldb.SCOPE_SUBTREE,
                                  base=self.account_dn,
                                  attrs=["dn"])
        self.assertEqual(len(found), 1)
        account = str(found[0].dn)
        pwd_version = self.samdb.get_attribute_replmetadata_version(
            account, "unicodePwd")
        self.assertEqual(pwd_version, 2)

    def test_set_attribute_replmetadata_version(self):
        # Raise the description metadata version by two and read it back.
        found = self.samdb.search(scope=ldb.SCOPE_SUBTREE,
                                  base=self.account_dn,
                                  attrs=["dn"])
        self.assertEqual(len(found), 1)
        account = str(found[0].dn)
        before = self.samdb.get_attribute_replmetadata_version(
            account, "description")
        self.samdb.set_attribute_replmetadata_version(
            account, "description", before + 2)
        after = self.samdb.get_attribute_replmetadata_version(
            account, "description")
        self.assertEqual(after, before + 2)

    def test_no_error_on_invalid_control(self):
        # A search carrying the DSDB_CONTROL_INVALID_NOT_IMPLEMENTED
        # control with a trailing ":0" must not raise; the ":1" variant
        # is covered by test_error_on_invalid_critical_control.
        try:
            self.samdb.search(
                scope=ldb.SCOPE_SUBTREE,
                base=self.account_dn,
                attrs=["replPropertyMetaData"],
                controls=[
                    "local_oid:%s:0" %
                    dsdb.DSDB_CONTROL_INVALID_NOT_IMPLEMENTED
                ])
        except ldb.LdbError:
            self.fail("Should have not raised an exception")

    def test_error_on_invalid_critical_control(self):
        # A search carrying the DSDB_CONTROL_INVALID_NOT_IMPLEMENTED
        # control with a trailing ":1": any LdbError raised must be
        # ERR_UNSUPPORTED_CRITICAL_EXTENSION.
        # NOTE(review): a search that raises nothing also passes this
        # test; confirm whether that is intended.
        try:
            self.samdb.search(
                scope=ldb.SCOPE_SUBTREE,
                base=self.account_dn,
                attrs=["replPropertyMetaData"],
                controls=[
                    "local_oid:%s:1" %
                    dsdb.DSDB_CONTROL_INVALID_NOT_IMPLEMENTED
                ])
        except ldb.LdbError as e:
            (errno, estr) = e.args
            if errno != ldb.ERR_UNSUPPORTED_CRITICAL_EXTENSION:
                # Use the unpacked message: exceptions cannot be
                # subscripted on Python 3.
                self.fail(
                    "Got %s should have got ERR_UNSUPPORTED_CRITICAL_EXTENSION"
                    % estr)

    # Allocate a unique RID for use in the objectSID tests.
    #
    def allocate_rid(self):
        """Allocate a RID inside its own transaction and return it as a
        string.

        The transaction is committed on success; if allocation raises,
        it is cancelled and the exception is re-raised.
        """
        self.samdb.transaction_start()
        try:
            new_rid = self.samdb.allocate_rid()
        except BaseException:
            self.samdb.transaction_cancel()
            raise
        else:
            self.samdb.transaction_commit()
        return str(new_rid)

    # Ensure that duplicate objectSIDs are permitted for foreign security
    # principals.
    #
    def test_duplicate_objectSIDs_allowed_on_foreign_security_principals(self):
        # A foreignSecurityPrincipal can only be added with the
        # "provision:0" control, and with it the same DN can be added
        # again after a delete.

        #
        # Build a foreign security principal SID, i.e. a SID that is not
        # in the current domain: swap the last character of the domain
        # SID for a different digit and append "-1000".
        #
        domain_sid_str = str(self.samdb.get_domain_sid())
        last_digit = "9" if domain_sid_str.endswith("0") else "0"
        sid_str = domain_sid_str[:-1] + last_digit + "-1000"
        sid = ndr_pack(security.dom_sid(sid_str))
        basedn = self.samdb.get_default_basedn()
        dn = "CN=%s,CN=ForeignSecurityPrincipals,%s" % (sid_str, basedn)

        #
        # Without the control and without objectSid: rejected as an
        # object class violation (missing required attribute).
        #
        try:
            self.samdb.add({
                "dn": dn,
                "objectClass": "foreignSecurityPrincipal"
            })
            self.fail("No exception should get ERR_OBJECT_CLASS_VIOLATION")
        except ldb.LdbError as err:
            code, text = err.args
            self.assertEqual(code, ldb.ERR_OBJECT_CLASS_VIOLATION, str(err))
            werr = "%08X" % werror.WERR_DS_MISSING_REQUIRED_ATT
            self.assertTrue(werr in text, text)

        #
        # Without the control but with an explicit objectSid: rejected
        # as unwilling to perform (illegal modify operation).
        #
        try:
            self.samdb.add({
                "dn": dn,
                "objectClass": "foreignSecurityPrincipal",
                "objectSid": sid
            })
            self.fail("No exception should get ERR_UNWILLING_TO_PERFORM")
        except ldb.LdbError as err:
            code, text = err.args
            self.assertEqual(code, ldb.ERR_UNWILLING_TO_PERFORM, str(err))
            werr = "%08X" % werror.WERR_DS_ILLEGAL_MOD_OPERATION
            self.assertTrue(werr in text, text)

        #
        # We need to use the provision control
        # in order to add foreignSecurityPrincipal
        # objects
        #
        provision = ["provision:0"]
        self.samdb.add({
            "dn": dn,
            "objectClass": "foreignSecurityPrincipal"
        }, controls=provision)

        self.samdb.delete(dn)

        # Adding the same DN a second time must also succeed.
        try:
            self.samdb.add({
                "dn": dn,
                "objectClass": "foreignSecurityPrincipal"
            }, controls=provision)
        except ldb.LdbError as err:
            code, text = err.args
            self.fail("Got unexpected exception %d - %s " % (code, text))

        # cleanup
        self.samdb.delete(dn)

    def _test_foreignSecurityPrincipal(self, obj_class, fpo_attr):
        """Check how <SID=...> values are handled when added to fpo_attr.

        An object of class obj_class is created under cn=Users and three
        SID references are added to its fpo_attr attribute in turn:

        * a SID below the local domain SID, expected to be refused with
          ERR_UNWILLING_TO_PERFORM / WERR_DS_INVALID_GROUP_TYPE;
        * a SID below S-1-5-32, expected to be refused with
          ERR_NO_SUCH_OBJECT / WERR_NO_SUCH_MEMBER;
        * a SID directly below S-1-5, expected to be accepted, after
          which exactly one object carrying that objectSid must exist.

        :param obj_class: objectClass of the test object
        :param fpo_attr: attribute the SID references are added to
        """
        domain_sid = self.samdb.get_domain_sid()
        local_sid = str(domain_sid) + "-4294967294"
        builtin_sid = "S-1-5-32-4294967294"
        foreign_sid = "S-1-5-4294967294"
        basedn = self.samdb.get_default_basedn()
        target_str = "cn=%s,cn=Users,%s" % ("dsdb_test_fpo", basedn)
        target_dn = ldb.Dn(self.samdb, target_str)

        def search_sid(sid_str):
            # All objects below basedn that carry the given objectSid.
            return self.samdb.search(scope=ldb.SCOPE_SUBTREE,
                                     base=basedn,
                                     expression="(objectSid=%s)" % sid_str,
                                     attrs=[])

        def add_sid_reference(sid_str):
            # Append "<SID=...>" to fpo_attr on the test object.
            request = ldb.Message()
            request.dn = target_dn
            request[fpo_attr] = ldb.MessageElement("<SID=%s>" % sid_str,
                                                   ldb.FLAG_MOD_ADD,
                                                   fpo_attr)
            self.samdb.modify(request)

        # None of the three SIDs may exist before the test starts.
        for sid_str in (local_sid, builtin_sid, foreign_sid):
            self.assertEqual(len(search_sid(sid_str)), 0)

        self.addCleanup(delete_force, self.samdb, target_str)

        self.samdb.add({"dn": target_str, "objectClass": obj_class})

        # (SID, name used in the failure message, LDB code, WERROR value)
        rejected = (
            (local_sid, "LDB_ERR_UNWILLING_TO_PERFORM",
             ldb.ERR_UNWILLING_TO_PERFORM,
             werror.WERR_DS_INVALID_GROUP_TYPE),
            (builtin_sid, "LDB_ERR_NO_SUCH_OBJECT",
             ldb.ERR_NO_SUCH_OBJECT,
             werror.WERR_NO_SUCH_MEMBER),
        )
        for (sid_str, ldb_name, ldb_code, win_code) in rejected:
            try:
                add_sid_reference(sid_str)
                self.fail("No exception should get %s" % ldb_name)
            except ldb.LdbError as e:
                (code, text) = e.args
                self.assertEqual(code, ldb_code, str(e))
                werr = "%08X" % win_code
                self.assertTrue(werr in text, text)

        try:
            add_sid_reference(foreign_sid)
        except ldb.LdbError:
            self.fail("Should have not raised an exception")

        # The accepted reference must be backed by exactly one object
        # with that SID; remove it together with the test object.
        found = search_sid(foreign_sid)
        self.assertEqual(len(found), 1)
        self.samdb.delete(found[0].dn)
        self.samdb.delete(target_dn)
        self.assertEqual(len(search_sid(foreign_sid)), 0)

    def test_foreignSecurityPrincipal_member(self):
        """Run the SID reference checks on 'member' of a group."""
        return self._test_foreignSecurityPrincipal(obj_class="group",
                                                   fpo_attr="member")

    def test_foreignSecurityPrincipal_MembersForAzRole(self):
        """Run the SID reference checks on msDS-MembersForAzRole."""
        return self._test_foreignSecurityPrincipal(
            obj_class="msDS-AzRole",
            fpo_attr="msDS-MembersForAzRole")

    def test_foreignSecurityPrincipal_NeverRevealGroup(self):
        """Run the SID reference checks on msDS-NeverRevealGroup."""
        return self._test_foreignSecurityPrincipal(
            obj_class="computer",
            fpo_attr="msDS-NeverRevealGroup")

    def test_foreignSecurityPrincipal_RevealOnDemandGroup(self):
        """Run the SID reference checks on msDS-RevealOnDemandGroup."""
        return self._test_foreignSecurityPrincipal(
            obj_class="computer",
            fpo_attr="msDS-RevealOnDemandGroup")

    def _test_fail_foreignSecurityPrincipal(self,
                                            obj_class,
                                            fpo_attr,
                                            msg_exp,
                                            lerr_exp,
                                            werr_exp,
                                            allow_reference=True):
        """Check that fpo_attr refuses every <SID=...> reference.

        Two objects of class obj_class are created under cn=Users.  Adding
        a SID below the local domain SID, a SID below S-1-5-32 and a SID
        directly below S-1-5 to fpo_attr of the first object must each
        fail with lerr_exp / werr_exp.  Finally the DN of the second
        object is added to fpo_attr, which must succeed or fail depending
        on allow_reference.

        :param obj_class: objectClass of the two test objects
        :param fpo_attr: attribute the references are added to
        :param msg_exp: description of the expected error, used in the
            failure message when a modify unexpectedly succeeds
        :param lerr_exp: expected LDB error code
        :param werr_exp: expected WERROR value; its "%08X" rendering must
            appear in the LDB error message
        :param allow_reference: True if adding the DN of the second test
            object must succeed, False if it must fail with the same
            error as the SID references
        """
        dom_sid = self.samdb.get_domain_sid()
        lsid_str = str(dom_sid) + "-4294967294"
        bsid_str = "S-1-5-32-4294967294"
        fsid_str = "S-1-5-4294967294"
        basedn = self.samdb.get_default_basedn()
        cn1 = "dsdb_test_fpo1"
        dn1_str = "cn=%s,cn=Users,%s" % (cn1, basedn)
        dn1 = ldb.Dn(self.samdb, dn1_str)
        cn2 = "dsdb_test_fpo2"
        dn2_str = "cn=%s,cn=Users,%s" % (cn2, basedn)
        dn2 = ldb.Dn(self.samdb, dn2_str)

        def assert_expected_error(e):
            # Both the LDB code and the WERROR text must match.
            (code, emsg) = e.args
            self.assertEqual(code, lerr_exp, str(e))
            werr = "%08X" % werr_exp
            self.assertTrue(werr in emsg, emsg)

        # None of the three SIDs may exist before the test starts.
        for sid_str in (lsid_str, bsid_str, fsid_str):
            res = self.samdb.search(scope=ldb.SCOPE_SUBTREE,
                                    base=basedn,
                                    expression="(objectSid=%s)" % sid_str,
                                    attrs=[])
            self.assertEqual(len(res), 0)

        self.addCleanup(delete_force, self.samdb, dn1_str)
        self.addCleanup(delete_force, self.samdb, dn2_str)

        self.samdb.add({"dn": dn1_str, "objectClass": obj_class})

        self.samdb.add({"dn": dn2_str, "objectClass": obj_class})

        # Every SID reference has to be rejected with the same error.
        # The failure message always names msg_exp; previously the third
        # case formatted the ldb.Message object into it instead.
        for sid_str in (lsid_str, bsid_str, fsid_str):
            msg = ldb.Message()
            msg.dn = dn1
            msg[fpo_attr] = ldb.MessageElement("<SID=%s>" % sid_str,
                                               ldb.FLAG_MOD_ADD, fpo_attr)
            try:
                self.samdb.modify(msg)
                self.fail("No exception should get %s" % msg_exp)
            except ldb.LdbError as e:
                assert_expected_error(e)

        # A plain DN reference to the second object.
        msg = ldb.Message()
        msg.dn = dn1
        msg[fpo_attr] = ldb.MessageElement("%s" % dn2, ldb.FLAG_MOD_ADD,
                                           fpo_attr)
        try:
            self.samdb.modify(msg)
            if not allow_reference:
                self.fail("No exception should get %s" % msg_exp)
        except ldb.LdbError as e:
            if allow_reference:
                self.fail("Should have not raised an exception: %s" % e)
            assert_expected_error(e)

        self.samdb.delete(dn2)
        self.samdb.delete(dn1)

    def test_foreignSecurityPrincipal_NonMembers(self):
        """msDS-NonMembers on a group must refuse SID and DN references."""
        return self._test_fail_foreignSecurityPrincipal(
            obj_class="group",
            fpo_attr="msDS-NonMembers",
            msg_exp="LDB_ERR_UNWILLING_TO_PERFORM/WERR_NOT_SUPPORTED",
            lerr_exp=ldb.ERR_UNWILLING_TO_PERFORM,
            werr_exp=werror.WERR_NOT_SUPPORTED,
            allow_reference=False)

    def test_foreignSecurityPrincipal_HostServiceAccount(self):
        """msDS-HostServiceAccount on a computer must refuse SID references."""
        return self._test_fail_foreignSecurityPrincipal(
            obj_class="computer",
            fpo_attr="msDS-HostServiceAccount",
            msg_exp="LDB_ERR_CONSTRAINT_VIOLATION/WERR_DS_NAME_REFERENCE_INVALID",
            lerr_exp=ldb.ERR_CONSTRAINT_VIOLATION,
            werr_exp=werror.WERR_DS_NAME_REFERENCE_INVALID)

    def test_foreignSecurityPrincipal_manager(self):
        """'manager' on a user must refuse SID references."""
        return self._test_fail_foreignSecurityPrincipal(
            obj_class="user",
            fpo_attr="manager",
            msg_exp="LDB_ERR_CONSTRAINT_VIOLATION/WERR_DS_NAME_REFERENCE_INVALID",
            lerr_exp=ldb.ERR_CONSTRAINT_VIOLATION,
            werr_exp=werror.WERR_DS_NAME_REFERENCE_INVALID)

    #
    # Duplicate objectSID's should not be permitted for sids in the local
    # domain. The test sequence is add an object, delete it, then attempt to
    # re-add it, this should fail with a constraint violation
    #
    def test_duplicate_objectSIDs_not_allowed_on_local_objects(self):
        """Re-adding a deleted user with the same local SID must fail.

        A user with an explicit objectSID below the domain SID is added
        and deleted; adding it again with the same SID is expected to
        raise ERR_CONSTRAINT_VIOLATION.
        """
        domain_sid = self.samdb.get_domain_sid()
        new_rid = self.allocate_rid()
        full_sid = str(domain_sid) + "-" + new_rid
        packed_sid = ndr_pack(security.dom_sid(full_sid))
        basedn = self.samdb.get_default_basedn()
        user_dn = "cn=%s,cn=Users,%s" % ("dsdb_test_01", basedn)

        def add_user():
            # Same DN and objectSID on every call.
            self.samdb.add({"dn": user_dn,
                            "objectClass": "user",
                            "objectSID": packed_sid})

        add_user()
        self.samdb.delete(user_dn)

        try:
            add_user()
            self.fail("No exception should get LDB_ERR_CONSTRAINT_VIOLATION")
        except ldb.LdbError as e:
            (code, text) = e.args
            if code == ldb.ERR_CONSTRAINT_VIOLATION:
                return
            self.fail("Got %d - %s should have got "
                      "LDB_ERR_CONSTRAINT_VIOLATION" % (code, text))

    def test_linked_vs_non_linked_reference(self):
        """Reference a deleted object through 'manager' and 'assistant'.

        Two users are created and one of them is deleted again.  The
        deleted user is then referenced by <SID=...> and <GUID=...> from
        the remaining user:

        * through 'manager' both forms must fail with
          ERR_CONSTRAINT_VIOLATION / WERR_DS_NAME_REFERENCE_INVALID;
        * through 'assistant' both forms must be accepted (and are
          removed again);
        * through 'assistant' a DN, SID or GUID that never existed must
          fail with the same error as above.
        """
        basedn = self.samdb.get_default_basedn()
        kept_dn_str = "cn=reference_kept,cn=Users,%s" % (basedn)
        removed_dn_str = "cn=reference_removed,cn=Users,%s" % (basedn)
        dom_sid = self.samdb.get_domain_sid()
        none_sid_str = str(dom_sid) + "-4294967294"
        none_guid_str = "afafafaf-fafa-afaf-fafa-afafafafafaf"

        self.addCleanup(delete_force, self.samdb, kept_dn_str)
        self.addCleanup(delete_force, self.samdb, removed_dn_str)

        def create_user(dn_str):
            # Add the user and read back its DN, GUID and SID.
            self.samdb.add({"dn": dn_str, "objectClass": "user"})
            found = self.samdb.search(scope=ldb.SCOPE_SUBTREE,
                                      base=dn_str,
                                      attrs=["objectGUID", "objectSID"])
            self.assertEqual(len(found), 1)
            guid = ndr_unpack(misc.GUID, found[0]["objectGUID"][0])
            sid = ndr_unpack(security.dom_sid, found[0]["objectSid"][0])
            return (found[0].dn, guid, sid)

        # The GUID and SID of the kept user are decoded but not used.
        (kept_dn, _kept_guid, _kept_sid) = create_user(kept_dn_str)
        (_removed_dn, removed_guid, removed_sid) = create_user(removed_dn_str)
        self.samdb.delete(removed_dn_str)

        def modify_kept(attr, value, flags):
            # Apply a single-attribute change to the kept user.
            request = ldb.Message()
            request.dn = kept_dn
            request[attr] = ldb.MessageElement(value, flags, attr)
            self.samdb.modify(request)

        def assert_reference_invalid(attr, value):
            # Adding value to attr must be refused as an invalid reference.
            try:
                modify_kept(attr, value, ldb.FLAG_MOD_ADD)
                self.fail("No exception should get "
                          "LDB_ERR_CONSTRAINT_VIOLATION")
            except ldb.LdbError as e:
                (code, text) = e.args
                self.assertEqual(code, ldb.ERR_CONSTRAINT_VIOLATION, str(e))
                werr = "%08X" % werror.WERR_DS_NAME_REFERENCE_INVALID
                self.assertTrue(werr in text, text)

        removed_by_sid = "<SID=%s>" % removed_sid
        removed_by_guid = "<GUID=%s>" % removed_guid

        #
        # First try the linked attribute 'manager'
        # by GUID and SID
        #
        assert_reference_invalid("manager", removed_by_sid)
        assert_reference_invalid("manager", removed_by_guid)

        #
        # Try the non-linked attribute 'assistant'
        # by GUID and SID, which should work.
        #
        for reference in (removed_by_sid, removed_by_guid):
            modify_kept("assistant", reference, ldb.FLAG_MOD_ADD)
            modify_kept("assistant", reference, ldb.FLAG_MOD_DELETE)

        #
        # Finally try the non-linked attribute 'assistant'
        # but with non existing GUID, SID, DN
        #
        assert_reference_invalid("assistant", "CN=NoneNone,%s" % (basedn))
        assert_reference_invalid("assistant", "<SID=%s>" % none_sid_str)
        assert_reference_invalid("assistant", "<GUID=%s>" % none_guid_str)

        self.samdb.delete(kept_dn)

    def test_normalize_dn_in_domain_full(self):
        """A complete DN passed as a string must come back equal to itself."""
        expected = ldb.Dn(self.samdb, "CN=Users")
        expected.add_base(self.samdb.domain_dn())

        # That is, no change
        normalized = self.samdb.normalize_dn_in_domain(str(expected))
        self.assertEqual(expected, normalized)

    def test_normalize_dn_in_domain_part(self):
        """A partial DN passed as a string must gain the domain DN."""
        partial = "CN=Users"

        expected = ldb.Dn(self.samdb, partial)
        expected.add_base(self.samdb.domain_dn())

        # That is, the domain DN appended
        normalized = self.samdb.normalize_dn_in_domain(partial)
        self.assertEqual(expected, normalized)

    def test_normalize_dn_in_domain_full_dn(self):
        """A complete DN passed as an ldb.Dn must come back equal to itself."""
        expected = ldb.Dn(self.samdb, "CN=Users")
        expected.add_base(self.samdb.domain_dn())

        # That is, no change
        normalized = self.samdb.normalize_dn_in_domain(expected)
        self.assertEqual(expected, normalized)

    def test_normalize_dn_in_domain_part_dn(self):
        """A partial DN passed as an ldb.Dn must gain the domain DN."""
        domain_dn = self.samdb.domain_dn()
        partial = ldb.Dn(self.samdb, "CN=Users")

        # That is, the domain DN appended
        expected = ldb.Dn(self.samdb, "%s,%s" % (partial, domain_dn))
        normalized = self.samdb.normalize_dn_in_domain(partial)
        self.assertEqual(expected, normalized)
Exemple #27
0
class DsdbLockTestCase(SamDBTestCase):
    """Locking tests that race a forked child against the parent process.

    Every test forks.  One process holds a lock (a prepared transaction
    or an open search iterator) for about two seconds while the other
    process times an operation that is expected to block on that lock.
    The two processes are sequenced with short messages sent over pipes,
    and the parent checks that the child exited with status 0.
    """

    def test_db_lock1(self):
        """A search must wait for a transaction prepared by another process.

        The child prepares a commit, reports b"prepared", keeps the
        transaction open for two seconds and then cancels it.  The parent
        times opening and draining a search iterator, which is expected
        to take more than 1.9 seconds.
        """
        basedn = self.samdb.get_default_basedn()
        (r1, w1) = os.pipe()

        pid = os.fork()
        if pid == 0:
            # In the child, close the main DB, re-open just one DB
            del (self.samdb)
            gc.collect()
            self.samdb = SamDB(session_info=self.session, lp=self.lp)

            self.samdb.transaction_start()

            dn = "cn=test_db_lock_user,cn=users," + str(basedn)
            self.samdb.add({
                "dn": dn,
                "objectclass": "user",
            })
            self.samdb.delete(dn)

            # Obtain a write lock
            self.samdb.transaction_prepare_commit()
            os.write(w1, b"prepared")
            time.sleep(2)

            # Drop the write lock
            self.samdb.transaction_cancel()
            os._exit(0)

        self.assertEqual(os.read(r1, 8), b"prepared")

        start = time.time()

        # We need to hold this iterator open to hold the all-record lock.
        res = self.samdb.search_iterator()

        # This should take at least 2 seconds because the transaction
        # has a write lock on one backend db open

        # Release the locks
        for _ in res:
            pass

        end = time.time()
        self.assertGreater(end - start, 1.9)

        (got_pid, status) = os.waitpid(pid, 0)
        self.assertEqual(got_pid, pid)
        self.assertTrue(os.WIFEXITED(status))
        self.assertEqual(os.WEXITSTATUS(status), 0)

    def test_db_lock2(self):
        """prepare_commit must wait for a search held open by the child.

        The child keeps a search iterator open for two seconds while the
        parent, inside a transaction that added and deleted a user,
        times transaction_prepare_commit().  The call is expected to
        take more than 1.9 seconds.
        """
        basedn = self.samdb.get_default_basedn()
        (r1, w1) = os.pipe()
        (r2, w2) = os.pipe()

        pid = os.fork()
        if pid == 0:
            # In the child, close the main DB, re-open
            del (self.samdb)
            gc.collect()
            self.samdb = SamDB(session_info=self.session, lp=self.lp)

            # We need to hold this iterator open to hold the all-record lock.
            res = self.samdb.search_iterator()

            os.write(w2, b"start")
            if (os.read(r1, 7) != b"started"):
                os._exit(1)

            os.write(w2, b"add")
            if (os.read(r1, 5) != b"added"):
                os._exit(2)

            # Wait 2 seconds to block prepare_commit() in the parent.
            os.write(w2, b"prepare")
            time.sleep(2)

            # Release the locks
            for _ in res:
                pass

            if (os.read(r1, 8) != b"prepared"):
                os._exit(3)

            os._exit(0)

        # We can start the transaction during the search
        # because both just grab the all-record read lock.
        self.assertEqual(os.read(r2, 5), b"start")
        self.samdb.transaction_start()
        os.write(w1, b"started")

        self.assertEqual(os.read(r2, 3), b"add")
        dn = "cn=test_db_lock_user,cn=users," + str(basedn)
        self.samdb.add({
            "dn": dn,
            "objectclass": "user",
        })
        self.samdb.delete(dn)
        os.write(w1, b"added")

        # Obtain a write lock, this will block until
        # the child releases the read lock.
        self.assertEqual(os.read(r2, 7), b"prepare")
        start = time.time()
        self.samdb.transaction_prepare_commit()
        end = time.time()
        try:
            self.assertGreater(end - start, 1.9)
        finally:
            # Always let the child finish and drop our transaction,
            # even when the timing assertion failed.
            os.write(w1, b"prepared")

            # Drop the write lock
            self.samdb.transaction_cancel()

            (got_pid, status) = os.waitpid(pid, 0)
            self.assertEqual(got_pid, pid)
            self.assertTrue(os.WIFEXITED(status))
            self.assertEqual(os.WEXITSTATUS(status), 0)

    def test_db_lock3(self):
        """Like test_db_lock2, but the parent writes "@DSDB_LOCK_TEST".

        The child keeps a search iterator open for two seconds while the
        parent times transaction_prepare_commit() on a transaction that
        added and deleted the "@DSDB_LOCK_TEST" record.  The call is
        expected to take more than 1.9 seconds.
        """
        basedn = self.samdb.get_default_basedn()
        (r1, w1) = os.pipe()
        (r2, w2) = os.pipe()

        pid = os.fork()
        if pid == 0:
            # In the child, close the main DB, re-open
            del (self.samdb)
            gc.collect()
            self.samdb = SamDB(session_info=self.session, lp=self.lp)

            # We need to hold this iterator open to hold the all-record lock.
            res = self.samdb.search_iterator()

            os.write(w2, b"start")
            if (os.read(r1, 7) != b"started"):
                os._exit(1)

            os.write(w2, b"add")
            if (os.read(r1, 5) != b"added"):
                os._exit(2)

            # Wait 2 seconds to block prepare_commit() in the parent.
            os.write(w2, b"prepare")
            time.sleep(2)

            # Release the locks
            for _ in res:
                pass

            if (os.read(r1, 8) != b"prepared"):
                os._exit(3)

            os._exit(0)

        # We can start the transaction during the search
        # because both just grab the all-record read lock.
        self.assertEqual(os.read(r2, 5), b"start")
        self.samdb.transaction_start()
        os.write(w1, b"started")

        self.assertEqual(os.read(r2, 3), b"add")

        # This will end up in the top level db
        dn = "@DSDB_LOCK_TEST"
        self.samdb.add({"dn": dn})
        self.samdb.delete(dn)
        os.write(w1, b"added")

        # Obtain a write lock, this will block until
        # the child releases the read lock.
        self.assertEqual(os.read(r2, 7), b"prepare")
        start = time.time()
        self.samdb.transaction_prepare_commit()
        end = time.time()
        try:
            self.assertGreater(end - start, 1.9)
        finally:
            # As in test_db_lock2: release the child and cancel the
            # transaction even when the timing assertion failed, so a
            # failure does not leave a prepared transaction or an
            # unreaped child behind.
            os.write(w1, b"prepared")

            # Drop the write lock
            self.samdb.transaction_cancel()

            (got_pid, status) = os.waitpid(pid, 0)
            self.assertTrue(os.WIFEXITED(status))
            self.assertEqual(os.WEXITSTATUS(status), 0)
            self.assertEqual(got_pid, pid)

    def _test_full_db_lock1(self, backend_path):
        """A search must wait for a transaction prepared on one backend.

        The child opens only the file at backend_path, prepares a commit
        on it, reports b"prepared", waits two seconds and cancels.  The
        parent times search_iterator() on the full database, which is
        expected to take more than 1.9 seconds.

        :param backend_path: path of the backend ldb file the child opens
        """
        (r1, w1) = os.pipe()

        pid = os.fork()
        if pid == 0:
            # In the child, close the main DB, re-open just one DB
            del (self.samdb)
            gc.collect()

            backenddb = ldb.Ldb(backend_path)

            backenddb.transaction_start()

            backenddb.add({"dn": "@DSDB_LOCK_TEST"})
            backenddb.delete("@DSDB_LOCK_TEST")

            # Obtain a write lock
            backenddb.transaction_prepare_commit()
            os.write(w1, b"prepared")
            time.sleep(2)

            # Drop the write lock
            backenddb.transaction_cancel()
            os._exit(0)

        self.assertEqual(os.read(r1, 8), b"prepared")

        start = time.time()

        # We need to hold this iterator open to hold the all-record lock.
        res = self.samdb.search_iterator()

        # This should take at least 2 seconds because the transaction
        # has a write lock on one backend db open

        end = time.time()
        self.assertGreater(end - start, 1.9)

        # Release the locks
        for _ in res:
            pass

        (got_pid, status) = os.waitpid(pid, 0)
        self.assertEqual(got_pid, pid)
        self.assertTrue(os.WIFEXITED(status))
        self.assertEqual(os.WEXITSTATUS(status), 0)

    def _backend_path(self, basedn):
        """Return the private path of the sam.ldb.d backend file for basedn.

        The file name is the case-folded form of basedn followed by
        ".ldb".
        """
        backend_filename = "%s.ldb" % basedn.get_casefold()
        backend_subpath = os.path.join("sam.ldb.d", backend_filename)
        return self.lp.private_path(backend_subpath)

    def test_full_db_lock1(self):
        """Run _test_full_db_lock1 against the default base DN backend."""
        basedn = self.samdb.get_default_basedn()
        self._test_full_db_lock1(self._backend_path(basedn))

    def test_full_db_lock1_config(self):
        """Run _test_full_db_lock1 against the config base DN backend."""
        basedn = self.samdb.get_config_basedn()
        self._test_full_db_lock1(self._backend_path(basedn))

    def _test_full_db_lock2(self, backend_path):
        """prepare_commit on one backend must wait for a full-DB search.

        The child keeps a search iterator open on the full database for
        two seconds.  The parent opens only the file at backend_path and
        times transaction_prepare_commit() on it, which is expected to
        take more than 1.9 seconds.

        :param backend_path: path of the backend ldb file the parent opens
        """
        (r1, w1) = os.pipe()
        (r2, w2) = os.pipe()

        pid = os.fork()
        if pid == 0:

            # In the child, close the main DB, re-open
            del (self.samdb)
            gc.collect()
            self.samdb = SamDB(session_info=self.session, lp=self.lp)

            # We need to hold this iterator open to hold the all-record lock.
            res = self.samdb.search_iterator()

            os.write(w2, b"start")
            if (os.read(r1, 7) != b"started"):
                os._exit(1)
            os.write(w2, b"add")
            if (os.read(r1, 5) != b"added"):
                os._exit(2)

            # Wait 2 seconds to block prepare_commit() in the parent.
            os.write(w2, b"prepare")
            time.sleep(2)

            # Release the locks
            for _ in res:
                pass

            if (os.read(r1, 8) != b"prepared"):
                os._exit(3)

            os._exit(0)

        # In the parent, close the main DB, re-open just one DB
        del (self.samdb)
        gc.collect()
        backenddb = ldb.Ldb(backend_path)

        # We can start the transaction during the search
        # because both just grab the all-record read lock.
        self.assertEqual(os.read(r2, 5), b"start")
        backenddb.transaction_start()
        os.write(w1, b"started")

        self.assertEqual(os.read(r2, 3), b"add")
        backenddb.add({"dn": "@DSDB_LOCK_TEST"})
        backenddb.delete("@DSDB_LOCK_TEST")
        os.write(w1, b"added")

        # Obtain a write lock, this will block until
        # the child releases the read lock.
        self.assertEqual(os.read(r2, 7), b"prepare")
        start = time.time()
        backenddb.transaction_prepare_commit()
        end = time.time()

        try:
            self.assertGreater(end - start, 1.9)
        finally:
            # Always let the child finish and drop our transaction,
            # even when the timing assertion failed.
            os.write(w1, b"prepared")

            # Drop the write lock
            backenddb.transaction_cancel()

            (got_pid, status) = os.waitpid(pid, 0)
            self.assertEqual(got_pid, pid)
            self.assertTrue(os.WIFEXITED(status))
            self.assertEqual(os.WEXITSTATUS(status), 0)

    def test_full_db_lock2(self):
        """Run _test_full_db_lock2 against the default base DN backend."""
        basedn = self.samdb.get_default_basedn()
        self._test_full_db_lock2(self._backend_path(basedn))

    def test_full_db_lock2_config(self):
        """Run _test_full_db_lock2 against the config base DN backend."""
        basedn = self.samdb.get_config_basedn()
        self._test_full_db_lock2(self._backend_path(basedn))
Exemple #28
0
class RodcTests(samba.tests.TestCase):
    """Check that write attempts against ``HOST`` are refused with referrals.

    Each test tries an add, modify or delete through ``self.samdb`` and
    expects an ``ldb.LdbError`` carrying ``ldb.ERR_REFERRAL`` (or
    ``ldb.ERR_NO_SUCH_OBJECT`` when the target does not exist).
    Presumably ``HOST`` is a read-only DC, going by the class name.
    """

    def setUp(self):
        """Connect to ``HOST`` and record the base DN, service DN and a tag."""
        super(RodcTests, self).setUp()
        self.samdb = SamDB(HOST,
                           credentials=CREDS,
                           session_info=system_session(LP),
                           lp=LP)

        self.base_dn = self.samdb.domain_dn()

        root = self.samdb.search(base='',
                                 scope=ldb.SCOPE_BASE,
                                 attrs=['dsServiceName'])
        self.service = root[0]['dsServiceName'][0]
        # Random tag so object names do not collide between runs.
        self.tag = uuid.uuid4().hex

    def _referral_address(self, emsg):
        """Return the ``ldap://`` URL embedded in a referral error message.

        Fails the test when *emsg* contains no ``ldap://...>`` fragment.
        """
        found = re.search(r'(ldap://[^>]+)>', emsg)
        if found is None:
            self.fail("referral seems not to refer to anything")
        return found.group(1)

    def _assert_specific_dc(self, address):
        """Fail the test if *address* starts with the domain's DNS name."""
        # NOTE(review): `address` always begins with "ldap://" (see the
        # pattern in _referral_address), so this prefix test can presumably
        # never match a bare DNS name -- confirm whether the scheme should
        # be stripped before comparing.
        if address.lower().startswith(self.samdb.domain_dns_name()):
            self.fail("referral address did not give a specific DC")

    def test_add_replicated_objects(self):
        """Adds must be referred, and the referred server must accept them."""
        candidates = (
            {
                'dn': "ou=%s1,%s" % (self.tag, self.base_dn),
                "objectclass": "organizationalUnit"
            },
            {
                'dn': "cn=%s2,%s" % (self.tag, self.base_dn),
                "objectclass": "user"
            },
            {
                'dn': "cn=%s3,%s" % (self.tag, self.base_dn),
                "objectclass": "group"
            },
            {
                'dn': "cn=%s4,%s" % (self.tag, self.service),
                "objectclass": "NTDSConnection",
                'enabledConnection': 'TRUE',
                'fromServer': self.base_dn,
                'options': '0'
            },
        )
        for o in candidates:
            try:
                self.samdb.add(o)
                self.fail("Failed to fail to add %s" % o['dn'])
            except ldb.LdbError as e:
                (ecode, emsg) = e.args
                if ecode != ldb.ERR_REFERRAL:
                    print(emsg)
                    self.fail("Adding %s: ldb error: %s %s, wanted referral" %
                              (o['dn'], ecode, emsg))

                address = self._referral_address(emsg)

                # The referred-to server must accept the very same add
                # (and we clean up after ourselves there).
                try:
                    tmpdb = SamDB(address,
                                  credentials=CREDS,
                                  session_info=system_session(LP),
                                  lp=LP)
                    tmpdb.add(o)
                    tmpdb.delete(o['dn'])
                except ldb.LdbError:
                    self.fail("couldn't modify referred location %s" %
                              address)

                self._assert_specific_dc(address)

    def test_modify_replicated_attributes(self):
        """Modifying carLicense/middleName on Guest must be referred."""
        dn = 'CN=Guest,CN=Users,' + self.base_dn
        value = 'hallooo'
        for attr in ['carLicense', 'middleName']:
            msg = ldb.Message()
            msg.dn = ldb.Dn(self.samdb, dn)
            msg[attr] = ldb.MessageElement(value, ldb.FLAG_MOD_REPLACE, attr)
            try:
                self.samdb.modify(msg)
                self.fail("Failed to fail to modify %s %s" % (dn, attr))
            except ldb.LdbError as e1:
                (ecode, emsg) = e1.args
                if ecode != ldb.ERR_REFERRAL:
                    self.fail("Failed to REFER when trying to modify %s %s" %
                              (dn, attr))

                address = self._referral_address(emsg)

                # The referred-to server must accept the same modification.
                try:
                    tmpdb = SamDB(address,
                                  credentials=CREDS,
                                  session_info=system_session(LP),
                                  lp=LP)
                    tmpdb.modify(msg)
                except ldb.LdbError:
                    self.fail("couldn't modify referred location %s" %
                              address)

                self._assert_specific_dc(address)

    def test_modify_nonreplicated_attributes(self):
        """Modifying badPwdCount/lastLogon/lastLogoff must be referred."""
        dn = 'CN=Guest,CN=Users,' + self.base_dn
        value = '123456789'
        for attr in ['badPwdCount', 'lastLogon', 'lastLogoff']:
            msg = ldb.Message()
            msg.dn = ldb.Dn(self.samdb, dn)
            msg[attr] = ldb.MessageElement(value, ldb.FLAG_MOD_REPLACE, attr)
            # Windows refers these ones even though they are non-replicated
            try:
                self.samdb.modify(msg)
                self.fail("Failed to fail to modify %s %s" % (dn, attr))
            except ldb.LdbError as e2:
                (ecode, emsg) = e2.args
                if ecode != ldb.ERR_REFERRAL:
                    self.fail("Failed to REFER when trying to modify %s %s" %
                              (dn, attr))

                self._assert_specific_dc(self._referral_address(emsg))

    def test_modify_nonreplicated_reps_attributes(self):
        """Modifying repsFrom on the base DN must be referred."""
        dn = self.base_dn

        msg = ldb.Message()
        msg.dn = ldb.Dn(self.samdb, dn)
        attr = 'repsFrom'

        # Take the existing first repsFrom value and change one field so
        # that the replacement differs from what is stored.
        res = self.samdb.search(dn, scope=ldb.SCOPE_BASE, attrs=['repsFrom'])
        rep = ndr_unpack(drsblobs.repsFromToBlob,
                         res[0]['repsFrom'][0],
                         allow_remaining=True)
        rep.ctr.result_last_attempt = -1
        value = ndr_pack(rep)

        msg[attr] = ldb.MessageElement(value, ldb.FLAG_MOD_REPLACE, attr)
        try:
            self.samdb.modify(msg)
            self.fail("Failed to fail to modify %s %s" % (dn, attr))
        except ldb.LdbError as e3:
            (ecode, emsg) = e3.args
            if ecode != ldb.ERR_REFERRAL:
                self.fail("Failed to REFER when trying to modify %s %s" %
                          (dn, attr))

            self._assert_specific_dc(self._referral_address(emsg))

    def test_delete_special_objects(self):
        """Deleting the Guest user must be referred."""
        dn = 'CN=Guest,CN=Users,' + self.base_dn
        try:
            self.samdb.delete(dn)
            self.fail("Failed to fail to delete %s" % (dn))
        except ldb.LdbError as e4:
            (ecode, emsg) = e4.args
            if ecode != ldb.ERR_REFERRAL:
                print(ecode, emsg)
                self.fail("Failed to REFER when trying to delete %s" % dn)

            self._assert_specific_dc(self._referral_address(emsg))

    def test_no_delete_nonexistent_objects(self):
        """Deleting a nonexistent object must give ERR_NO_SUCH_OBJECT."""
        dn = 'CN=does-not-exist-%s,CN=Users,%s' % (self.tag, self.base_dn)
        try:
            self.samdb.delete(dn)
            self.fail("Failed to fail to delete %s" % (dn))
        except ldb.LdbError as e5:
            (ecode, emsg) = e5.args
            if ecode != ldb.ERR_NO_SUCH_OBJECT:
                print(ecode, emsg)
                self.fail("Failed to NO_SUCH_OBJECT when trying to delete "
                          "%s (which does not exist)" % dn)
Exemple #29
0
class LATests(samba.tests.TestCase):
    """Linked-attribute tests run against ``host`` inside ``OU=la,<base>``.

    ``setUp`` creates the OU and ``tearDown`` tree-deletes it (unless
    ``opts.no_cleanup`` is set), so every test starts from an empty OU.
    """

    def setUp(self):
        """Connect to ``host`` and create the working OU."""
        super(LATests, self).setUp()
        self.samdb = SamDB(host,
                           credentials=creds,
                           session_info=system_session(lp),
                           lp=lp)

        self.base_dn = self.samdb.domain_dn()
        self.ou = "OU=la,%s" % self.base_dn
        if opts.delete_in_setup:
            # Best effort: a leftover OU from an earlier run may or may
            # not exist, so a failure here is only reported.
            try:
                self.samdb.delete(self.ou, ['tree_delete:1'])
            except ldb.LdbError as e:
                print("tried deleting %s, got error %s" % (self.ou, e))
        self.samdb.add({'objectclass': 'organizationalUnit', 'dn': self.ou})

    def tearDown(self):
        """Tree-delete the working OU unless ``opts.no_cleanup`` is set."""
        super(LATests, self).tearDown()
        if not opts.no_cleanup:
            self.samdb.delete(self.ou, ['tree_delete:1'])

    def add_object(self, cn, objectclass, more_attrs=None):
        """Add ``CN=<cn>`` under the working OU and return its DN string.

        :param cn: common name of the new object
        :param objectclass: value for the ``objectclass`` attribute
        :param more_attrs: optional dict of extra attributes; its entries
            override the generated ``cn``/``objectclass``/``dn`` on clash
        """
        dn = "CN=%s,%s" % (cn, self.ou)
        attrs = {'cn': cn, 'objectclass': objectclass, 'dn': dn}
        if more_attrs:
            attrs.update(more_attrs)
        self.samdb.add(attrs)

        return dn

    def add_objects(self, n, objectclass, prefix=None, more_attrs=None):
        """Add *n* objects named ``<prefix>1`` .. ``<prefix>n``.

        :param prefix: name prefix; defaults to *objectclass*
        :param more_attrs: optional extra attributes given to every object
        :return: list of the new DN strings, in creation order
        """
        if prefix is None:
            prefix = objectclass
        dns = []
        for i in range(n):
            dns.append(
                self.add_object("%s%d" % (prefix, i + 1),
                                objectclass,
                                more_attrs=more_attrs))
        return dns

    def add_linked_attribute(self, src, dest, attr='member', controls=None):
        """Add value(s) *dest* to *attr* on *src* (FLAG_MOD_ADD)."""
        m = ldb.Message()
        m.dn = ldb.Dn(self.samdb, src)
        m[attr] = ldb.MessageElement(dest, ldb.FLAG_MOD_ADD, attr)
        self.samdb.modify(m, controls=controls)

    def remove_linked_attribute(self, src, dest, attr='member', controls=None):
        """Remove value(s) *dest* from *attr* on *src* (FLAG_MOD_DELETE)."""
        m = ldb.Message()
        m.dn = ldb.Dn(self.samdb, src)
        m[attr] = ldb.MessageElement(dest, ldb.FLAG_MOD_DELETE, attr)
        self.samdb.modify(m, controls=controls)

    def replace_linked_attribute(self,
                                 src,
                                 dest,
                                 attr='member',
                                 controls=None):
        """Replace *attr* on *src* with value(s) *dest* (FLAG_MOD_REPLACE)."""
        m = ldb.Message()
        m.dn = ldb.Dn(self.samdb, src)
        m[attr] = ldb.MessageElement(dest, ldb.FLAG_MOD_REPLACE, attr)
        self.samdb.modify(m, controls=controls)

    def attr_search(self, obj, attr, scope=ldb.SCOPE_BASE, **controls):
        """Search *obj* for the single attribute *attr*.

        Each keyword argument ``name=value`` is sent as the control string
        ``"name:<int(value)>"``. ``reveal_internals`` is dropped when
        ``opts.no_reveal_internals`` is set.
        """
        if opts.no_reveal_internals:
            controls.pop('reveal_internals', None)

        control_strings = ['%s:%d' % (k, int(v)) for k, v in controls.items()]

        res = self.samdb.search(obj,
                                scope=scope,
                                attrs=[attr],
                                controls=control_strings)
        return res

    def assert_links(self, obj, expected, attr, msg='', **kwargs):
        """Assert that *attr* on *obj* holds exactly the values *expected*.

        Order is ignored (both sides are sorted). An empty *expected*
        means the attribute must be absent from the first result.
        *kwargs* are passed to :meth:`attr_search` as controls; *msg* is
        printed ahead of the mismatch details.
        """
        res = self.attr_search(obj, attr, **kwargs)

        if len(expected) == 0:
            if attr in res[0]:
                self.fail("found attr '%s' in %s" % (attr, res[0]))
            return

        try:
            results = list([x[attr] for x in res][0])
        except KeyError:
            self.fail("missing attr '%s' on %s" % (attr, obj))

        expected = sorted(expected)
        results = sorted(results)

        if expected != results:
            print(msg)
            print("expected %s" % expected)
            print("received %s" % results)

        self.assertEqual(results, expected)

    def assert_back_links(self, obj, expected, attr='memberOf', **kwargs):
        """:meth:`assert_links` for a back-link attribute (``memberOf``)."""
        self.assert_links(obj,
                          expected,
                          attr=attr,
                          msg='back links do not match',
                          **kwargs)

    def assert_forward_links(self, obj, expected, attr='member', **kwargs):
        """:meth:`assert_links` for a forward-link attribute (``member``)."""
        self.assert_links(obj,
                          expected,
                          attr=attr,
                          msg='forward links do not match',
                          **kwargs)

    def get_object_guid(self, dn):
        """Return the ``objectGUID`` of *dn* as a string."""
        res = self.samdb.search(dn, scope=ldb.SCOPE_BASE, attrs=['objectGUID'])
        return str(misc.GUID(res[0]['objectGUID'][0]))

    @staticmethod
    def _ldb_error_names():
        """Map every integer ``ldb.ERR_*`` value to its constant name."""
        # .items() rather than .iteritems(): the latter only exists on
        # Python 2 and would turn a test failure into an AttributeError.
        return {
            v: k
            for k, v in vars(ldb).items()
            if k.startswith('ERR_') and isinstance(v, int)
        }

    def assertRaisesLdbError(self, errcode, msg, f, *args, **kwargs):
        """Assert a function raises a particular LdbError.

        :param errcode: the expected ``ldb.ERR_*`` code
        :param msg: description of the operation, used in failure messages
        :param f: callable invoked as ``f(*args, **kwargs)``
        """
        try:
            f(*args, **kwargs)
        except ldb.LdbError as e:
            # Keep the server's text separate from the caller's `msg` so
            # the failure report shows both.
            (num, ldb_msg) = e.args
            if num != errcode:
                lut = self._ldb_error_names()
                self.fail("%s, expected "
                          "LdbError %s, (%d) "
                          "got %s (%d): %s" %
                          (msg, lut.get(errcode), errcode, lut.get(num), num,
                           ldb_msg))
        else:
            lut = self._ldb_error_names()
            self.fail("%s, expected "
                      "LdbError %s, (%d) "
                      "but we got success" % (msg, lut.get(errcode), errcode))

    def _test_la_backlinks(self, reveal=False):
        """Check memberOf back links after adding members to two groups.

        With *reveal* the searches also carry ``reveal_internals:0``.
        """
        tag = 'backlinks'
        kwargs = {}
        if reveal:
            tag += '_reveal'
            kwargs = {'reveal_internals': 0}

        u1, u2 = self.add_objects(2, 'user', 'u_%s' % tag)
        g1, g2 = self.add_objects(2, 'group', 'g_%s' % tag)

        self.add_linked_attribute(g1, u1)
        self.add_linked_attribute(g2, u1)
        self.add_linked_attribute(g2, u2)

        self.assert_back_links(u1, [g1, g2], **kwargs)
        self.assert_back_links(u2, [g2], **kwargs)

    def test_la_backlinks(self):
        """Back links without the reveal_internals control."""
        self._test_la_backlinks()

    def test_la_backlinks_reveal(self):
        """Back links with reveal_internals; skipped on request."""
        if opts.no_reveal_internals:
            print('skipping because --no-reveal-internals')
            return
        self._test_la_backlinks(True)

    def _test_la_backlinks_delete_group(self, reveal=False):
        """Check back links disappear when a group is deleted.

        With *reveal* the searches also carry ``reveal_internals:0``.
        """
        tag = 'del_group'
        kwargs = {}
        if reveal:
            tag += '_reveal'
            kwargs = {'reveal_internals': 0}

        u1, u2 = self.add_objects(2, 'user', 'u_' + tag)
        g1, g2 = self.add_objects(2, 'group', 'g_' + tag)

        self.add_linked_attribute(g1, u1)
        self.add_linked_attribute(g2, u1)
        self.add_linked_attribute(g2, u2)

        self.samdb.delete(g2, ['tree_delete:1'])

        self.assert_back_links(u1, [g1], **kwargs)
        self.assert_back_links(u2, set(), **kwargs)

    def test_la_backlinks_delete_group(self):
        """Group deletion without the reveal_internals control."""
        self._test_la_backlinks_delete_group()

    def test_la_backlinks_delete_group_reveal(self):
        """Group deletion with reveal_internals; skipped on request."""
        if opts.no_reveal_internals:
            print('skipping because --no-reveal-internals')
            return
        self._test_la_backlinks_delete_group(True)

    def test_links_all_delete_group(self):
        """After deleting a group, its links are gone even when searching
        with show_deleted/show_recycled/show_deactivated_link controls.
        """
        u1, u2 = self.add_objects(2, 'user', 'u_all_del_group')
        g1, g2 = self.add_objects(2, 'group', 'g_all_del_group')
        # Remember the GUID: the deleted group is looked up by it below.
        g2guid = self.get_object_guid(g2)

        self.add_linked_attribute(g1, u1)
        self.add_linked_attribute(g2, u1)
        self.add_linked_attribute(g2, u2)

        self.samdb.delete(g2)
        self.assert_back_links(u1, [g1],
                               show_deleted=1,
                               show_recycled=1,
                               show_deactivated_link=0)
        self.assert_back_links(u2,
                               set(),
                               show_deleted=1,
                               show_recycled=1,
                               show_deactivated_link=0)
        self.assert_forward_links(g1, [u1],
                                  show_deleted=1,
                                  show_recycled=1,
                                  show_deactivated_link=0)
        self.assert_forward_links('<GUID=%s>' % g2guid, [],
                                  show_deleted=1,
                                  show_recycled=1,
                                  show_deactivated_link=0)

    def test_links_all_delete_group_reveal(self):
        """Same as test_links_all_delete_group, adding reveal_internals:0."""
        u1, u2 = self.add_objects(2, 'user', 'u_all_del_group_reveal')
        g1, g2 = self.add_objects(2, 'group', 'g_all_del_group_reveal')
        g2guid = self.get_object_guid(g2)

        self.add_linked_attribute(g1, u1)
        self.add_linked_attribute(g2, u1)
        self.add_linked_attribute(g2, u2)

        self.samdb.delete(g2)
        self.assert_back_links(u1, [g1],
                               show_deleted=1,
                               show_recycled=1,
                               show_deactivated_link=0,
                               reveal_internals=0)
        self.assert_back_links(u2,
                               set(),
                               show_deleted=1,
                               show_recycled=1,
                               show_deactivated_link=0,
                               reveal_internals=0)
        self.assert_forward_links(g1, [u1],
                                  show_deleted=1,
                                  show_recycled=1,
                                  show_deactivated_link=0,
                                  reveal_internals=0)
        self.assert_forward_links('<GUID=%s>' % g2guid, [],
                                  show_deleted=1,
                                  show_recycled=1,
                                  show_deactivated_link=0,
                                  reveal_internals=0)

    def test_la_links_delete_link(self):
        """Adding/removing links changes uSNChanged on the group and
        leaves the expected forward links behind.
        """
        u1, u2 = self.add_objects(2, 'user', 'u_del_link')
        g1, g2 = self.add_objects(2, 'group', 'g_del_link')

        res = self.samdb.search(g1, scope=ldb.SCOPE_BASE, attrs=['uSNChanged'])
        old_usn1 = int(res[0]['uSNChanged'][0])

        self.add_linked_attribute(g1, u1)

        res = self.samdb.search(g1, scope=ldb.SCOPE_BASE, attrs=['uSNChanged'])
        new_usn1 = int(res[0]['uSNChanged'][0])

        self.assertNotEqual(old_usn1, new_usn1, "USN should have incremented")

        self.add_linked_attribute(g2, u1)
        self.add_linked_attribute(g2, u2)

        res = self.samdb.search(g2, scope=ldb.SCOPE_BASE, attrs=['uSNChanged'])
        old_usn2 = int(res[0]['uSNChanged'][0])

        self.remove_linked_attribute(g2, u1)

        res = self.samdb.search(g2, scope=ldb.SCOPE_BASE, attrs=['uSNChanged'])
        new_usn2 = int(res[0]['uSNChanged'][0])

        self.assertNotEqual(old_usn2, new_usn2, "USN should have incremented")

        self.assert_forward_links(g1, [u1])
        self.assert_forward_links(g2, [u2])

        self.add_linked_attribute(g2, u1)
        self.assert_forward_links(g2, [u1, u2])
        self.remove_linked_attribute(g2, u2)
        self.assert_forward_links(g2, [u1])
        self.remove_linked_attribute(g2, u1)
        self.assert_forward_links(g2, [])
        # A delete with no values removes the whole attribute.
        self.remove_linked_attribute(g1, [])
        self.assert_forward_links(g1, [])

        # removing a duplicate link in the same message should fail
        self.add_linked_attribute(g2, [u1, u2])
        self.assertRaises(ldb.LdbError, self.remove_linked_attribute, g2,
                          [u1, u1])

    def _test_la_links_delete_link_reveal(self):
        """With reveal_internals, a removed link is still returned."""
        u1, u2 = self.add_objects(2, 'user', 'u_del_link_reveal')
        g1, g2 = self.add_objects(2, 'group', 'g_del_link_reveal')

        self.add_linked_attribute(g1, u1)
        self.add_linked_attribute(g2, u1)
        self.add_linked_attribute(g2, u2)

        self.remove_linked_attribute(g2, u1)

        self.assert_forward_links(g2, [u1, u2],
                                  show_deleted=1,
                                  show_recycled=1,
                                  show_deactivated_link=0,
                                  reveal_internals=0)

    def test_la_links_delete_link_reveal(self):
        """Removed-link visibility test; skipped on request."""
        if opts.no_reveal_internals:
            print('skipping because --no-reveal-internals')
            return
        self._test_la_links_delete_link_reveal()

    def test_la_links_delete_user(self):
        """Deleting a user drops it from its groups without changing the
        groups' uSNChanged.
        """
        u1, u2 = self.add_objects(2, 'user', 'u_del_user')
        g1, g2 = self.add_objects(2, 'group', 'g_del_user')

        self.add_linked_attribute(g1, u1)
        self.add_linked_attribute(g2, u1)
        self.add_linked_attribute(g2, u2)

        res = self.samdb.search(g1, scope=ldb.SCOPE_BASE, attrs=['uSNChanged'])
        old_usn1 = int(res[0]['uSNChanged'][0])

        res = self.samdb.search(g2, scope=ldb.SCOPE_BASE, attrs=['uSNChanged'])
        old_usn2 = int(res[0]['uSNChanged'][0])

        self.samdb.delete(u1)

        self.assert_forward_links(g1, [])
        self.assert_forward_links(g2, [u2])

        res = self.samdb.search(g1, scope=ldb.SCOPE_BASE, attrs=['uSNChanged'])
        new_usn1 = int(res[0]['uSNChanged'][0])

        res = self.samdb.search(g2, scope=ldb.SCOPE_BASE, attrs=['uSNChanged'])
        new_usn2 = int(res[0]['uSNChanged'][0])

        # Assert the USN on the alternate object is unchanged
        self.assertEqual(old_usn1, new_usn1)
        self.assertEqual(old_usn2, new_usn2)

    def test_la_links_delete_user_reveal(self):
        """Deleted user's links stay hidden even with reveal_internals:0
        and the show_* controls.
        """
        u1, u2 = self.add_objects(2, 'user', 'u_del_user_reveal')
        g1, g2 = self.add_objects(2, 'group', 'g_del_user_reveal')

        self.add_linked_attribute(g1, u1)
        self.add_linked_attribute(g2, u1)
        self.add_linked_attribute(g2, u2)

        self.samdb.delete(u1)

        self.assert_forward_links(g2, [u2],
                                  show_deleted=1,
                                  show_recycled=1,
                                  show_deactivated_link=0,
                                  reveal_internals=0)
        self.assert_forward_links(g1, [],
                                  show_deleted=1,
                                  show_recycled=1,
                                  show_deactivated_link=0,
                                  reveal_internals=0)

    def test_multiple_links(self):
        """Add and remove several members per message; duplicates within
        one add must fail with ERR_ENTRY_ALREADY_EXISTS.
        """
        u1, u2, u3, u4 = self.add_objects(4, 'user', 'u_multiple_links')
        g1, g2, g3, g4 = self.add_objects(4, 'group', 'g_multiple_links')

        self.add_linked_attribute(g1, [u1, u2, u3, u4])
        self.add_linked_attribute(g2, [u3, u1])
        self.add_linked_attribute(g3, u2)

        self.assertRaisesLdbError(ldb.ERR_ENTRY_ALREADY_EXISTS,
                                  "adding duplicate values",
                                  self.add_linked_attribute, g2,
                                  [u1, u2, u3, u2])

        self.assert_forward_links(g1, [u1, u2, u3, u4])
        self.assert_forward_links(g2, [u3, u1])
        self.assert_forward_links(g3, [u2])
        self.assert_back_links(u1, [g2, g1])
        self.assert_back_links(u2, [g3, g1])
        self.assert_back_links(u3, [g2, g1])
        self.assert_back_links(u4, [g1])

        self.remove_linked_attribute(g2, [u1, u3])
        self.remove_linked_attribute(g1, [u1, u3])

        self.assert_forward_links(g1, [u2, u4])
        self.assert_forward_links(g2, [])
        self.assert_forward_links(g3, [u2])
        self.assert_back_links(u1, [])
        self.assert_back_links(u2, [g3, g1])
        self.assert_back_links(u3, [])
        self.assert_back_links(u4, [g1])

        self.add_linked_attribute(g1, [u1, u3])
        self.add_linked_attribute(g2, [u3, u1])
        self.add_linked_attribute(g3, [u1, u3])

        self.assert_forward_links(g1, [u1, u2, u3, u4])
        self.assert_forward_links(g2, [u1, u3])
        self.assert_forward_links(g3, [u1, u2, u3])
        self.assert_back_links(u1, [g1, g2, g3])
        self.assert_back_links(u2, [g3, g1])
        self.assert_back_links(u3, [g3, g2, g1])
        self.assert_back_links(u4, [g1])

    def test_la_links_replace(self):
        """Replace member lists and check forward and back links; a
        replacement containing duplicates must fail.
        """
        u1, u2, u3, u4 = self.add_objects(4, 'user', 'u_replace')
        g1, g2, g3, g4 = self.add_objects(4, 'group', 'g_replace')

        self.add_linked_attribute(g1, [u1, u2])
        self.add_linked_attribute(g2, [u1, u3])
        self.add_linked_attribute(g3, u1)

        self.replace_linked_attribute(g1, [u2])
        self.replace_linked_attribute(g2, [u2, u3])
        self.replace_linked_attribute(g3, [u1, u3])
        self.replace_linked_attribute(g4, [u4])

        self.assert_forward_links(g1, [u2])
        self.assert_forward_links(g2, [u3, u2])
        self.assert_forward_links(g3, [u3, u1])
        self.assert_forward_links(g4, [u4])
        self.assert_back_links(u1, [g3])
        self.assert_back_links(u2, [g1, g2])
        self.assert_back_links(u3, [g2, g3])
        self.assert_back_links(u4, [g4])

        self.replace_linked_attribute(g1, [u1, u2, u3])
        self.replace_linked_attribute(g2, [u1])
        self.replace_linked_attribute(g3, [u2])
        self.replace_linked_attribute(g4, [])

        self.assert_forward_links(g1, [u1, u2, u3])
        self.assert_forward_links(g2, [u1])
        self.assert_forward_links(g3, [u2])
        self.assert_forward_links(g4, [])
        self.assert_back_links(u1, [g1, g2])
        self.assert_back_links(u2, [g1, g3])
        self.assert_back_links(u3, [g1])
        self.assert_back_links(u4, [])

        self.assertRaisesLdbError(ldb.ERR_ENTRY_ALREADY_EXISTS,
                                  "replacing duplicate values",
                                  self.replace_linked_attribute, g2,
                                  [u1, u2, u3, u2])

    def test_la_links_replace2(self):
        """Replace with overlapping and disjoint member sets on one group."""
        users = self.add_objects(12, 'user', 'u_replace2')
        g1, = self.add_objects(1, 'group', 'g_replace2')

        self.add_linked_attribute(g1, users[:6])
        self.assert_forward_links(g1, users[:6])
        self.replace_linked_attribute(g1, users)
        self.assert_forward_links(g1, users)
        self.replace_linked_attribute(g1, users[6:])
        self.assert_forward_links(g1, users[6:])
        self.remove_linked_attribute(g1, users[6:9])
        self.assert_forward_links(g1, users[9:])
        self.remove_linked_attribute(g1, users[9:])
        self.assert_forward_links(g1, [])

    def test_la_links_permutations(self):
        """Make sure the order in which we add links doesn't matter."""
        users = self.add_objects(3, 'user', 'u_permutations')
        groups = self.add_objects(6, 'group', 'g_permutations')

        # 3 users have 6 permutations: one ordering per group.
        for g, p in zip(groups, itertools.permutations(users)):
            self.add_linked_attribute(g, p)

        # everyone should be in every group
        for g in groups:
            self.assert_forward_links(g, users)

        for u in users:
            self.assert_back_links(u, groups)

        for g, p in zip(groups[::-1], itertools.permutations(users)):
            self.replace_linked_attribute(g, p)

        for g in groups:
            self.assert_forward_links(g, users)

        for u in users:
            self.assert_back_links(u, groups)

        for g, p in zip(groups, itertools.permutations(users)):
            self.remove_linked_attribute(g, p)

        for g in groups:
            self.assert_forward_links(g, [])

        for u in users:
            self.assert_back_links(u, [])

    def test_la_links_relaxed(self):
        """Check that the relax control doesn't mess with linked attributes."""
        relax_control = ['relax:0']

        users = self.add_objects(10, 'user', 'u_relax')
        groups = self.add_objects(3,
                                  'group',
                                  'g_relax',
                                  more_attrs={'member': users[:2]})
        g_relax1, g_relax2, g_uptight = groups

        # g_relax1 has all users added at once
        # g_relax2 gets them one at a time in reverse order
        # g_uptight never relaxes

        self.add_linked_attribute(g_relax1, users[2:5], controls=relax_control)

        for u in reversed(users[2:5]):
            self.add_linked_attribute(g_relax2, u, controls=relax_control)
            self.add_linked_attribute(g_uptight, u)

        for g in groups:
            self.assert_forward_links(g, users[:5])

            self.add_linked_attribute(g, users[5:7])
            self.assert_forward_links(g, users[:7])

            for u in users[7:]:
                self.add_linked_attribute(g, u)

            self.assert_forward_links(g, users)

        for u in users:
            self.assert_back_links(u, groups)

        # try some replacement permutations
        import random
        # Fixed seed so the shuffles are the same on every run.
        random.seed(1)
        users2 = users[:]
        for i in range(5):
            random.shuffle(users2)
            self.replace_linked_attribute(g_relax1,
                                          users2,
                                          controls=relax_control)

            self.assert_forward_links(g_relax1, users)

        for i in range(5):
            random.shuffle(users2)
            self.remove_linked_attribute(g_relax2,
                                         users2,
                                         controls=relax_control)
            self.remove_linked_attribute(g_uptight, users2)

            self.replace_linked_attribute(g_relax1, [], controls=relax_control)

            random.shuffle(users2)
            self.add_linked_attribute(g_relax2, users2, controls=relax_control)
            self.add_linked_attribute(g_uptight, users2)
            self.replace_linked_attribute(g_relax1,
                                          users2,
                                          controls=relax_control)

            self.assert_forward_links(g_relax1, users)
            self.assert_forward_links(g_relax2, users)
            self.assert_forward_links(g_uptight, users)

        for u in users:
            self.assert_back_links(u, groups)

    def test_add_all_at_once(self):
        """All these other tests are creating linked attributes after the
        objects are there. We want to test creating them all at once
        using LDIF.
        """
        users = self.add_objects(7, 'user', 'u_all_at_once')
        g1, g3 = self.add_objects(2,
                                  'group',
                                  'g_all_at_once',
                                  more_attrs={'member': users})
        (g2, ) = self.add_objects(1,
                                  'group',
                                  'g_all_at_once2',
                                  more_attrs={'member': users[:5]})

        # users[1] appears twice in this member list.
        self.assertRaisesLdbError(
            ldb.ERR_ENTRY_ALREADY_EXISTS,
            "adding multiple duplicate values",
            self.add_objects,
            1,
            'group',
            'g_with_duplicate_links',
            more_attrs={'member': users[:5] + users[1:2]})

        self.assert_forward_links(g1, users)
        self.assert_forward_links(g2, users[:5])
        self.assert_forward_links(g3, users)
        for u in users[:5]:
            self.assert_back_links(u, [g1, g2, g3])
        for u in users[5:]:
            self.assert_back_links(u, [g1, g3])

        self.remove_linked_attribute(g2, users[0])
        self.remove_linked_attribute(g2, users[1])
        self.add_linked_attribute(g2, users[1])
        self.add_linked_attribute(g2, users[5])
        self.add_linked_attribute(g2, users[6])

        self.assert_forward_links(g1, users)
        self.assert_forward_links(g2, users[1:])

        for u in users[1:]:
            self.remove_linked_attribute(g2, u)
        self.remove_linked_attribute(g1, users)

        # g3 still lists every user; deleting the users must clear it.
        for u in users:
            self.samdb.delete(u)

        self.assert_forward_links(g1, [])
        self.assert_forward_links(g2, [])
        self.assert_forward_links(g3, [])

    def test_one_way_attributes(self):
        """An ``addressBookRoots`` link keeps pointing at its target after
        the target is deleted (at the target's new DN).
        """
        e1, e2 = self.add_objects(2, 'msExchConfigurationContainer',
                                  'e_one_way')
        guid = self.get_object_guid(e2)

        self.add_linked_attribute(e1, e2, attr="addressBookRoots")
        self.assert_forward_links(e1, [e2], attr='addressBookRoots')

        self.samdb.delete(e2)

        # Find where the deleted object now lives.
        res = self.samdb.search("<GUID=%s>" % guid,
                                scope=ldb.SCOPE_BASE,
                                controls=['show_deleted:1', 'show_recycled:1'])

        new_dn = str(res[0].dn)
        self.assert_forward_links(e1, [new_dn], attr='addressBookRoots')
        self.assert_forward_links(e1, [new_dn],
                                  attr='addressBookRoots',
                                  show_deactivated_link=0)

    def test_one_way_attributes_delete_link(self):
        """An explicitly removed ``addressBookRoots`` link is gone."""
        e1, e2 = self.add_objects(2, 'msExchConfigurationContainer',
                                  'e_one_way')

        self.add_linked_attribute(e1, e2, attr="addressBookRoots")
        self.assert_forward_links(e1, [e2], attr='addressBookRoots')

        self.remove_linked_attribute(e1, e2, attr="addressBookRoots")

        self.assert_forward_links(e1, [], attr='addressBookRoots')
        self.assert_forward_links(e1, [],
                                  attr='addressBookRoots',
                                  show_deactivated_link=0)

    def test_pretend_one_way_attributes(self):
        """An ``addressBookRoots2`` link disappears when its target is
        deleted.
        """
        e1, e2 = self.add_objects(2, 'msExchConfigurationContainer',
                                  'e_one_way')
        guid = self.get_object_guid(e2)

        self.add_linked_attribute(e1, e2, attr="addressBookRoots2")
        self.assert_forward_links(e1, [e2], attr='addressBookRoots2')

        self.samdb.delete(e2)
        res = self.samdb.search("<GUID=%s>" % guid,
                                scope=ldb.SCOPE_BASE,
                                controls=['show_deleted:1', 'show_recycled:1'])

        # Not compared against anything; indexing res[0] still requires
        # the deleted object to be findable by GUID.
        new_dn = str(res[0].dn)

        self.assert_forward_links(e1, [], attr='addressBookRoots2')
        self.assert_forward_links(e1, [],
                                  attr='addressBookRoots2',
                                  show_deactivated_link=0)

    def test_pretend_one_way_attributes_delete_link(self):
        """An explicitly removed ``addressBookRoots2`` link is gone."""
        e1, e2 = self.add_objects(2, 'msExchConfigurationContainer',
                                  'e_one_way')

        self.add_linked_attribute(e1, e2, attr="addressBookRoots2")
        self.assert_forward_links(e1, [e2], attr='addressBookRoots2')

        self.remove_linked_attribute(e1, e2, attr="addressBookRoots2")

        self.assert_forward_links(e1, [], attr='addressBookRoots2')
        self.assert_forward_links(e1, [],
                                  attr='addressBookRoots2',
                                  show_deactivated_link=0)
Exemple #30
0
class VLVTests(samba.tests.TestCase):
    def create_user(self, i, n, prefix='vlvtest', suffix='', attrs=None):
        """Build a test user, add it to the directory and remember it.

        The attribute values are derived from ``i`` and ``n`` so that
        different attributes give different sort orders.

        :param i: index of this user; drives the generated values
        :param n: total used together with ``i`` in the generated values
        :param prefix: start of the user's cn
        :param suffix: end of the user's cn
        :param attrs: optional dict of attributes that override or
            extend the generated ones
        :return: the user dict (with its ``dn``) that was passed to
            ``self.ldb.add`` and appended to ``self.users``
        """
        name = "%s%d%s" % (prefix, i, suffix)
        user = {
            'cn':
            name,
            "objectclass":
            "user",
            'givenName':
            "abcdefghijklmnopqrstuvwxyz"[i % 26],
            "roomNumber":
            "%sbc" % (n - i),
            "carLicense":
            "后来经",
            "employeeNumber":
            "%s%sx" % (abs(i * (99 - i)), '\n' * (i & 255)),
            "accountExpires":
            "%s" % (10**9 + 1000000 * i),
            "msTSExpireDate4":
            "19%02d0101010000.0Z" % (i % 100),
            "flags":
            str(i * (n - i)),
            "serialNumber":
            "abc %s%s%s" % (
                'AaBb |-/'[i & 7],
                ' 3z}'[i & 3],
                '"@'[i & 1],
            ),
        }

        # _user_broken_attrs tests are broken due to problems outside
        # of VLV.  The dict is intentionally never merged into `user`;
        # it is kept as a record of the attributes that cannot be
        # tested yet.
        _user_broken_attrs = {
            # Sort doesn't look past a NUL byte.
            "photo":
            "\x00%d" % (n - i),
            "audio":
            "%sn octet string %s%s ♫♬\x00lalala" %
            ('Aa'[i & 1], chr(i & 255), i),
            "displayNamePrintable":
            "%d\x00%c" % (i, i & 255),
            "adminDisplayName":
            "%d\x00b" % (n - i),
            "title":
            "%d%sb" % (n - i, '\x00' * i),
            "comment":
            "Favourite colour is %d" % (n % (i + 1)),

            # Names that vary only in case. Windows returns
            # equivalent addresses in the order they were put
            # in ('a st', 'A st',...).
            "street":
            "%s st" % (chr(65 | (i & 14) | ((i & 1) * 32))),
        }

        if attrs is not None:
            user.update(attrs)

        user['dn'] = "cn=%s,%s" % (user['cn'], self.ou)

        if opts.skip_attr_regex:
            match = re.compile(opts.skip_attr_regex).search
            # Iterate over a snapshot of the keys: deleting entries
            # while iterating the live key view raises RuntimeError on
            # Python 3.
            for k in list(user.keys()):
                if match(k):
                    del user[k]

        self.users.append(user)
        self.ldb.add(user)
        return user

    def setUp(self):
        """Connect to the server, create the test OU and fill it with users.

        When ``opts.delete_in_setup`` is set, any existing test OU is
        removed first; a failure to remove it is reported and ignored.
        The method also records which attributes belong to which sort
        category (binary, numeric, timestamp, int64, locale) for use by
        the tests.
        """
        super(VLVTests, self).setUp()
        self.ldb = SamDB(host,
                         credentials=creds,
                         session_info=system_session(lp),
                         lp=lp)

        self.base_dn = self.ldb.domain_dn()
        self.ou = "ou=vlv,%s" % self.base_dn
        if opts.delete_in_setup:
            try:
                self.ldb.delete(self.ou, ['tree_delete:1'])
            except ldb.LdbError as e:
                # Best effort only.  The call form of print is used
                # because it parses on both Python 2 and 3.
                print("tried deleting %s, got error %s" % (self.ou, e))
        self.ldb.add({"dn": self.ou, "objectclass": "organizationalUnit"})

        self.users = []
        for i in range(N_ELEMENTS):
            self.create_user(i, N_ELEMENTS)

        attrs = self.users[0].keys()
        self.binary_sorted_keys = [
            'audio', 'photo', "msTSExpireDate4", 'serialNumber',
            "displayNamePrintable"
        ]

        self.numeric_sorted_keys = ['flags', 'accountExpires']

        self.timestamp_keys = ['msTSExpireDate4']

        self.int64_keys = set(['accountExpires'])

        # Everything that is neither binary nor numeric is treated as
        # locale sorted.
        self.locale_sorted_keys = [
            x for x in attrs
            if x not in (self.binary_sorted_keys + self.numeric_sorted_keys)
        ]

        # don't try spaces, etc in cn
        self.delicate_keys = ['cn']

    def tearDown(self):
        """Run the base-class teardown, then remove the test OU.

        The OU is left in place when ``opts.delete_in_setup`` is set.
        """
        super(VLVTests, self).tearDown()
        if opts.delete_in_setup:
            return
        self.ldb.delete(self.ou, ['tree_delete:1'])

    def get_full_list(self, attr, include_cn=False):
        """Fetch the whole list sorted on ``attr`` through a VLV search.

        Doing the search with VLV means the response controls carry a
        VLV cookie.

        Returns ``(full_results, controls, sort_control)``.  With
        ``include_cn`` the results are ``(value, cn)`` tuples; without
        it they are the lower-cased attribute values.
        """
        half = len(self.users) // 2
        sort_control = "server_sort:1:0:%s" % attr
        vlv_search = "vlv:1:%d:%d:%d:0" % (half, half, half + 1)
        wanted = [attr, 'cn'] if include_cn else [attr]

        res = self.ldb.search(self.ou,
                              scope=ldb.SCOPE_ONELEVEL,
                              attrs=wanted,
                              controls=[sort_control, vlv_search])

        if include_cn:
            full_results = [(msg[attr][0], msg['cn'][0]) for msg in res]
        else:
            full_results = [msg[attr][0].lower() for msg in res]
        return full_results, res.controls, sort_control

    def get_expected_order(self, attr, expression=None):
        """Return the values of ``attr`` under the test OU in sorted order.

        Only the server-side sort control is used here (no VLV);
        ``expression`` optionally restricts the search.
        """
        res = self.ldb.search(self.ou,
                              scope=ldb.SCOPE_ONELEVEL,
                              expression=expression,
                              attrs=[attr],
                              controls=["server_sort:1:0:%s" % attr])
        ordered = []
        for msg in res:
            ordered.append(msg[attr][0])
        return ordered

    def delete_user(self, user):
        """Delete ``user`` from the directory and forget it locally."""
        self.ldb.delete(user['dn'])
        # list.remove drops the first equal entry, as index + del did.
        self.users.remove(user)

    def get_gte_tests_and_order(self, attr, expression=None):
        """Work out greater-than-or-equal probe values and where they land.

        Temporary users carrying probe values for ``attr`` are added,
        the sort order is read with them present, and they are deleted
        again.

        :param attr: the attribute being sorted on
        :param expression: optional search filter used for the expected
            order (the order with the probe users present is read
            without it)
        :return: ``(gte_order, expected_order, gte_map)`` where
            ``gte_order`` is the sorted value list with the probe users
            present, ``expected_order`` is the sorted value list without
            them, and ``gte_map`` maps each value of ``gte_order`` to an
            index into ``expected_order`` (``len(expected_order)`` when
            nothing in ``expected_order`` follows it)
        """
        expected_order = self.get_expected_order(attr, expression=expression)
        gte_users = []
        if attr in self.delicate_keys:
            gte_keys = [
                '3',
                'abc',
                '¹',
                'ŋđ¼³ŧ“«đð',
                '桑巴',
            ]
        elif attr in self.timestamp_keys:
            gte_keys = [
                '18560101010000.0Z',
                '19140103010000.0Z',
                '19560101010010.0Z',
                '19700101000000.0Z',
                '19991231211234.3Z',
                '20061111211234.0Z',
                '20390901041234.0Z',
                '25560101010000.0Z',
            ]
        elif attr not in self.numeric_sorted_keys:
            gte_keys = [
                '3',
                'abc',
                ' ',
                '!@#!@#!',
                'kōkako',
                '¹',
                'ŋđ¼³ŧ“«đð',
                '\n\t\t',
                '桑巴',
                'zzzz',
            ]
            if expected_order:
                gte_keys.append(expected_order[len(expected_order) // 2] +
                                ' tail')

        else:
            # "numeric" means positive integers
            # doesn't work with -1, 3.14, ' 3', '9' * 20
            gte_keys = ['3', '1' * 10, '1', '9' * 7, '0']

            if attr in self.int64_keys:
                gte_keys += ['3' * 12, '71' * 8]

        for i, x in enumerate(gte_keys):
            user = self.create_user(i,
                                    N_ELEMENTS,
                                    prefix='gte',
                                    attrs={attr: x})
            gte_users.append(user)

        gte_order = self.get_expected_order(attr)
        for user in gte_users:
            self.delete_user(user)

        # for sanity's sake
        expected_order_2 = self.get_expected_order(attr, expression=expression)
        self.assertEqual(expected_order, expected_order_2)

        # Map gte tests to indexes in expected order. This will break
        # if gte_order and expected_order are differently ordered (as
        # it should).
        gte_map = {}

        # index to the first one with each value
        index_map = {}
        for i, k in enumerate(expected_order):
            if k not in index_map:
                index_map[k] = i

        # Values not present in expected_order wait in `pending` until
        # the next value that is present; they all share its index.
        pending = []
        for k in gte_order:
            if k in index_map:
                i = index_map[k]
                gte_map[k] = i
                for waiting in pending:
                    gte_map[waiting] = i
                pending = []
            else:
                pending.append(k)

        # Anything still pending sorts after every expected value.
        for waiting in pending:
            gte_map[waiting] = len(expected_order)

        # Debugging aid, disabled by default.
        if False:
            print("gte_map:")
            for k in gte_order:
                print("   %10s => %10s" % (k, gte_map[k]))

        return gte_order, expected_order, gte_map

    def assertCorrectResults(self, results, expected_order, offset, before,
                             after):
        """A helper to calculate offsets correctly and say as much as possible
        when something goes wrong.

        :param results: the values actually returned by the search
        :param expected_order: the full expected list; entries may be
            tuples, in which case only their first item is compared
        :param offset: one-based VLV target position
        :param before: number of entries requested before the target
        :param after: number of entries requested after the target
        """
        # offset is one-based, so the window starts at offset - 1 - before.
        start = max(offset - before - 1, 0)
        end = offset + after
        expected_results = expected_order[start:end]

        # if it is a tuple with the cn, drop the cn
        if expected_results and isinstance(expected_results[0], tuple):
            expected_results = [x[0] for x in expected_results]

        if expected_results == results:
            return

        # Operands are wrapped in 1-tuples so that a tuple value cannot
        # be misread as a list of format arguments.
        print("expected order: %s" % (expected_order[:20],))
        if len(expected_order) > 20:
            print("... and %d more not shown" % (len(expected_order) - 20))

        print("offset %d before %d after %d" % (offset, before, after))
        print("start %d end %d" % (start, end))
        print("expected: %s" % (expected_results,))
        print("got     : %s" % (results,))
        self.assertEqual(expected_results, results)

    def test_server_vlv_with_cookie(self):
        """Walk VLV windows over every attribute, reusing the cookie.

        The first search for each attribute is made without a cookie;
        every later one passes the cookie taken from the previous
        response.
        """
        ignored = ('dn', 'objectclass')
        attrs = [key for key in self.users[0].keys() if key not in ignored]
        for attr in attrs:
            expected_order = self.get_expected_order(attr)
            sort_control = "server_sort:1:0:%s" % attr
            n = len(self.users)
            previous = None
            for before in [10, 0, 3, 1, 4, 5, 2]:
                for after in [0, 3, 1, 4, 5, 2, 7]:
                    first = max(1, before - 2)
                    stop = min(n - after + 2, n)
                    for offset in range(first, stop):
                        if previous is not None:
                            cookie = get_cookie(previous.controls, n)
                            vlv_search = ("vlv:1:%d:%d:%d:%s:%s" %
                                          (before, after, offset, n, cookie))
                        else:
                            vlv_search = "vlv:1:%d:%d:%d:0" % (before, after,
                                                               offset)

                        previous = self.ldb.search(
                            self.ou,
                            scope=ldb.SCOPE_ONELEVEL,
                            attrs=[attr],
                            controls=[sort_control, vlv_search])

                        values = [msg[attr][0] for msg in previous]
                        self.assertCorrectResults(values, expected_order,
                                                  offset, before, after)

    def run_index_tests_with_expressions(self, expressions):
        """Run offset-based VLV searches for each attribute and filter.

        Not every before/after combination is tried: ``after`` always
        equals ``before``.  The first search for each attribute/filter
        pair has no cookie; later ones reuse the previous response's.
        """
        ignored = ('dn', 'objectclass')
        attrs = [key for key in self.users[0].keys() if key not in ignored]
        for attr in attrs:
            for expression in expressions:
                expected_order = self.get_expected_order(attr, expression)
                sort_control = "server_sort:1:0:%s" % attr
                n = len(expected_order)
                previous = None
                for before in range(0, 11):
                    after = before
                    first = max(1, before - 2)
                    stop = min(n - after + 2, n)
                    for offset in range(first, stop):
                        if previous is not None:
                            cookie = get_cookie(previous.controls)
                            vlv_search = ("vlv:1:%d:%d:%d:%s:%s" %
                                          (before, after, offset, n, cookie))
                        else:
                            vlv_search = "vlv:1:%d:%d:%d:0" % (before, after,
                                                               offset)

                        previous = self.ldb.search(
                            self.ou,
                            expression=expression,
                            scope=ldb.SCOPE_ONELEVEL,
                            attrs=[attr],
                            controls=[sort_control, vlv_search])

                        values = [msg[attr][0] for msg in previous]
                        self.assertCorrectResults(values, expected_order,
                                                  offset, before, after)

    def test_server_vlv_with_expression(self):
        """What happens when we run the VLV with an expression?"""
        expressions = [
            "(objectClass=*)",
            "(cn=%s)" % self.users[-1]['cn'],
            "(roomNumber=%s)" % self.users[0]['roomNumber'],
        ]
        self.run_index_tests_with_expressions(expressions)

    def test_server_vlv_with_failing_expression(self):
        """What happens when we run the VLV on an expression that matches
        nothing?"""
        expressions = [
            "(samaccountname=testferf)",
            "(cn=hefalump)",
        ]
        self.run_index_tests_with_expressions(expressions)

    def run_gte_tests_with_expressions(self, expressions):
        """Run greater-than-or-equal VLV searches for each filter and
        attribute.

        Not every before/after combination is tried: ``after`` always
        equals ``before``.  The first search for each filter/attribute
        pair has no cookie; later ones reuse the previous response's.

        :param expressions: search filters to run the checks under
        """
        attrs = [
            x for x in self.users[0].keys() if x not in ('dn', 'objectclass')
        ]
        for expression in expressions:
            for attr in attrs:
                # expected_order is used as returned by the helper, which
                # has already checked it is unchanged after its temporary
                # users were deleted; it is not fetched again here.
                gte_order, expected_order, gte_map = \
                    self.get_gte_tests_and_order(attr, expression)
                # In case there is some order dependency, disorder tests
                gte_tests = gte_order[:]
                random.seed(2)
                random.shuffle(gte_tests)

                sort_control = "server_sort:1:0:%s" % attr
                res = None
                for before in range(0, 11):
                    after = before
                    for gte in gte_tests:
                        if res is not None:
                            cookie = get_cookie(res.controls)
                        else:
                            cookie = None
                        vlv_search = encode_vlv_control(before=before,
                                                        after=after,
                                                        gte=gte,
                                                        cookie=cookie)

                        res = self.ldb.search(
                            self.ou,
                            scope=ldb.SCOPE_ONELEVEL,
                            expression=expression,
                            attrs=[attr],
                            controls=[sort_control, vlv_search])

                        results = [x[attr][0] for x in res]
                        offset = gte_map.get(gte, len(expected_order))

                        # here offset is 0-based
                        start = max(offset - before, 0)
                        end = offset + 1 + after

                        expected_results = expected_order[start:end]

                        self.assertEqual(expected_results, results)

    def test_vlv_gte_with_expression(self):
        """What happens when we run the VLV with an expression?"""
        expressions = [
            "(objectClass=*)",
            "(cn=%s)" % self.users[-1]['cn'],
            "(roomNumber=%s)" % self.users[0]['roomNumber'],
        ]
        self.run_gte_tests_with_expressions(expressions)

    def test_vlv_gte_with_failing_expression(self):
        """What happens when we run the VLV on an expression that matches
        nothing?"""
        expressions = [
            "(samaccountname=testferf)",
            "(cn=hefalump)",
        ]
        self.run_gte_tests_with_expressions(expressions)

    def test_server_vlv_with_cookie_while_adding_and_deleting(self):
        """What happens if we add or remove items in the middle of the VLV?

        Nothing. The search and the sort is not repeated, and we only
        deal with the objects originally found.

        Entries deleted during the run are marked as ``None`` in the
        expected order and skipped when the expected window is built;
        added entries never appear in it.
        """
        attrs = ['cn'] + [
            x for x in self.users[0].keys() if x not in ('dn', 'objectclass')
        ]
        # list() keeps the concatenation valid where range() does not
        # return a list (Python 3).
        window_sizes = list(range(0, 3)) + [6, 11, 19]
        user_number = 0
        for attr in attrs:
            full_results, controls, sort_control = \
                self.get_full_list(attr, True)
            original_n = len(self.users)

            expected_order = full_results
            random.seed(1)

            for before in window_sizes:
                for after in window_sizes:
                    start = max(before - 1, 1)
                    end = max(start + 4, original_n - after + 2)
                    for offset in range(start, end):
                        cookie = get_cookie(controls, original_n)
                        vlv_search = encode_vlv_control(before=before,
                                                        after=after,
                                                        offset=offset,
                                                        n=original_n,
                                                        cookie=cookie)

                        res = self.ldb.search(
                            self.ou,
                            scope=ldb.SCOPE_ONELEVEL,
                            attrs=[attr],
                            controls=[sort_control, vlv_search])

                        controls = res.controls
                        results = [x[attr][0] for x in res]

                        # Clamp the requested window to the original list.
                        real_offset = max(1, min(offset, len(expected_order)))
                        begin_offset = max(real_offset - before - 1, 0)
                        real_before = min(before, real_offset - 1)
                        real_after = min(after,
                                         len(expected_order) - real_offset)
                        wanted = real_before + real_after + 1

                        expected_results = []
                        for x in expected_order[begin_offset:]:
                            if x is None:
                                # deleted since the original search
                                continue
                            expected_results.append(x[0])
                            if len(expected_results) == wanted:
                                break

                        if expected_results != results:
                            print("attr %s before %d after %d offset %d" %
                                  (attr, before, after, offset))
                        self.assertEqual(expected_results, results)

                        # Randomly add and/or delete a user.  The number
                        # and order of random calls is kept fixed so the
                        # seeded sequence is reproducible.
                        n = len(self.users)
                        if random.random() < 0.1 + (n < 5) * 0.05:
                            if n == 0:
                                i = 0
                            else:
                                i = random.randrange(n)
                            self.create_user(i,
                                             n,
                                             suffix='-%s' % user_number)
                            user_number += 1
                        if random.random() < 0.1 + (n > 50) * 0.02 and n:
                            index = random.randrange(n)
                            user = self.users.pop(index)

                            self.ldb.delete(user['dn'])

                            replaced = (user[attr], user['cn'])
                            if replaced in expected_order:
                                i = expected_order.index(replaced)
                                expected_order[i] = None

    def test_server_vlv_with_cookie_while_changing(self):
        """What happens if we modify items in the middle of the VLV?

        The expected behaviour (as found on Windows) is the sort is
        not repeated, but the changes in attributes are reflected.

        The DN order from the first full search is the reference: each
        later window must return the same DNs in the same positions,
        with values that reflect the modifications made so far.
        """
        attrs = [
            x for x in self.users[0].keys()
            if x not in ('dn', 'objectclass', 'cn')
        ]
        for attr in attrs:
            n_users = len(self.users)
            # Only used by the disabled comparison below.
            expected_order = [x.upper() for x in self.get_expected_order(attr)]
            sort_control = "server_sort:1:0:%s" % attr
            i = 0

            # First we'll fetch the whole list so we know the original
            # sort order. This is necessary because we don't know how
            # the server will order equivalent items. We are using the
            # dn as a key.
            half_n = n_users // 2
            vlv_search = "vlv:1:%d:%d:%d:0" % (half_n, half_n, half_n + 1)
            res = self.ldb.search(self.ou,
                                  scope=ldb.SCOPE_ONELEVEL,
                                  attrs=['dn', attr],
                                  controls=[sort_control, vlv_search])

            results = [x[attr][0].upper() for x in res]
            #self.assertEqual(expected_order, results)

            dn_order = [str(x['dn']) for x in res]
            values = results[:]

            for before in range(0, 3):
                for after in range(0, 3):
                    for offset in range(1 + before, n_users - after):
                        cookie = get_cookie(res.controls, len(self.users))
                        vlv_search = (
                            "vlv:1:%d:%d:%d:%s:%s" %
                            (before, after, offset, len(self.users), cookie))

                        res = self.ldb.search(
                            self.ou,
                            scope=ldb.SCOPE_ONELEVEL,
                            attrs=['dn', attr],
                            controls=[sort_control, vlv_search])

                        dn_results = [str(x['dn']) for x in res]
                        dn_expected = dn_order[offset - before - 1:offset +
                                               after]

                        self.assertEqual(dn_expected, dn_results)

                        results = [x[attr][0].upper() for x in res]

                        self.assertCorrectResults(results, values, offset,
                                                  before, after)

                        # Every third search, overwrite one entry's value
                        # with another entry's value.
                        i += 1
                        if i % 3 == 2:
                            if (attr in self.locale_sorted_keys
                                    or attr in self.binary_sorted_keys):
                                i1 = i % n_users
                                i2 = (i ^ 255) % n_users
                                dn1 = dn_order[i1]
                                v2 = values[i2]

                                # NOTE(review): this tests the *value*
                                # against a list of attribute names;
                                # `attr` may have been intended.  Left
                                # as is - confirm before changing.
                                if v2 in self.locale_sorted_keys:
                                    v2 += '-%d' % i

                                values[i1] = v2

                                m = ldb.Message()
                                m.dn = ldb.Dn(self.ldb, dn1)
                                m[attr] = ldb.MessageElement(
                                    v2, ldb.FLAG_MOD_REPLACE, attr)

                                self.ldb.modify(m)

    def test_server_vlv_fractions_with_cookie(self):
        """What happens when the count is set to an arbitrary number?

        In that case the offset and the count form a fraction, and the
        VLV should be centred at a point offset/count of the way
        through. For example, if offset is 3 and count is 6, the VLV
        should be looking around halfway. The actual algorithm is a
        bit fiddlier than that, because of the one-basedness of VLV.

        For each attribute the full list is fetched once through
        get_full_list(); every later search reuses the cookie from the
        previous response and sends ``denominator`` as the content
        count.
        """
        attrs = [
            x for x in self.users[0].keys() if x not in ('dn', 'objectclass')
        ]

        n_users = len(self.users)

        random.seed(4)

        for attr in attrs:
            full_results, controls, sort_control = self.get_full_list(attr)
            self.assertEqual(len(full_results), n_users)
            for before in range(0, 2):
                for after in range(0, 2):
                    # With these ranges neither denominator nor offset
                    # is ever 0, so the zero-handling branches below
                    # are not exercised as written.
                    for denominator in range(1, 20):
                        for offset in range(1, denominator + 3):
                            cookie = get_cookie(controls, len(self.users))
                            vlv_search = (
                                "vlv:1:%d:%d:%d:%s:%s" %
                                (before, after, offset, denominator, cookie))
                            try:
                                res = self.ldb.search(
                                    self.ou,
                                    scope=ldb.SCOPE_ONELEVEL,
                                    attrs=[attr],
                                    controls=[sort_control, vlv_search])
                            except ldb.LdbError as e:
                                # Only an error for offset 0 is tolerated;
                                # any other error fails the test.
                                if offset != 0:
                                    raise
                                print(
                                    "offset %d denominator %d raised error "
                                    "expected error %s\n"
                                    "(offset zero is illegal unless "
                                    "content count is zero)" %
                                    (offset, denominator, e))
                                continue

                            # get_full_list() lower-cases its values, so
                            # the results are lower-cased to match.
                            results = [x[attr][0].lower() for x in res]

                            # Work out which one-based position the
                            # offset/denominator fraction should map to.
                            if denominator == 0:
                                # NOTE(review): this branch does not set
                                # real_offset; it is unreachable with the
                                # current ranges - confirm the intended
                                # value before enabling denominator 0.
                                denominator = n_users
                                if offset == 0:
                                    offset = denominator
                            elif denominator == 1:
                                # the offset can only be 1, but the 1/1 case
                                # means something special
                                if offset == 1:
                                    real_offset = n_users
                                else:
                                    real_offset = 1
                            else:
                                # Offsets past the denominator are clamped
                                # to it before scaling onto 1..n_users.
                                if offset > denominator:
                                    offset = denominator
                                real_offset = (1 + int(
                                    round((n_users - 1) * (offset - 1) /
                                          (denominator - 1.0))))

                            self.assertCorrectResults(results, full_results,
                                                      real_offset, before,
                                                      after)

                            # Keep the response controls so the next
                            # search can reuse the cookie.
                            controls = res.controls
                            # Debugging aid, disabled by default.
                            if False:
                                for c in list(controls):
                                    cstr = str(c)
                                    if cstr.startswith('vlv_resp'):
                                        bits = cstr.rsplit(':')
                                        print("the answer is %s; we said %d" %
                                              (bits[2], real_offset))
                                        break

    def test_server_vlv_no_cookie(self):
        """Check VLV windows for every attribute without using a cookie.

        Each search is independent and always sends ``0`` as the
        content count.
        """
        ignored = ('dn', 'objectclass')
        attrs = [key for key in self.users[0].keys() if key not in ignored]

        for attr in attrs:
            expected_order = self.get_expected_order(attr)
            sort_control = "server_sort:1:0:%s" % attr
            for before in range(0, 5):
                for after in range(0, 7):
                    stop = len(self.users) - after
                    for offset in range(1 + before, stop):
                        vlv_search = ("vlv:1:%d:%d:%d:0" %
                                      (before, after, offset))
                        res = self.ldb.search(
                            self.ou,
                            scope=ldb.SCOPE_ONELEVEL,
                            attrs=[attr],
                            controls=[sort_control, vlv_search])
                        values = [msg[attr][0] for msg in res]
                        self.assertCorrectResults(values, expected_order,
                                                  offset, before, after)

    def get_expected_order_showing_deleted(
            self,
            attr,
            expression="(|(cn=vlvtest*)(cn=vlv-deleted*))",
            base=None,
            scope=ldb.SCOPE_SUBTREE):
        """Return the sorted values of ``attr``, deleted objects included.

        Uses the sort control together with ``show_deleted:1`` and no
        VLV.  By default the search covers the whole tree below
        ``self.base_dn`` rather than just the test OU, which is the way
        to find deleted objects.
        """
        search_base = self.base_dn if base is None else base
        res = self.ldb.search(
            search_base,
            scope=scope,
            expression=expression,
            attrs=[attr],
            controls=["server_sort:1:0:%s" % attr, "show_deleted:1"])
        return [msg[attr][0] for msg in res]

    def add_deleted_users(self, n):
        """Create ``n`` users named ``vlv-deleted*`` and then delete them.

        All of the users are created before any of them is deleted.
        """
        doomed = []
        for i in range(n):
            doomed.append(self.create_user(i, n, prefix='vlv-deleted'))

        for victim in doomed:
            self.delete_user(victim)

    def test_server_vlv_no_cookie_show_deleted(self):
        """Check cookie-less VLV windows with the show_deleted control.

        The searches cover the whole tree below ``self.base_dn`` and
        match both live and deleted test users.
        """
        attrs = [
            'objectGUID', 'cn', 'sAMAccountName', 'objectSid', 'name',
            'whenChanged', 'usnChanged'
        ]

        # add some deleted users first, just in case there are none
        self.add_deleted_users(6)
        random.seed(22)
        expression = "(|(cn=vlvtest*)(cn=vlv-deleted*))"
        show_deleted_control = "show_deleted:1"

        for attr in attrs:
            expected_order = self.get_expected_order_showing_deleted(
                attr, expression)
            n = len(expected_order)
            sort_control = "server_sort:1:0:%s" % attr
            for before in [3, 1, 0]:
                for after in [0, 2]:
                    low = max(1, before - 2)
                    high = min(n - after + 2, n)
                    # don't test every position, because there could be
                    # hundreds. jump back and forth instead
                    for _ in range(20):
                        offset = random.randrange(low, high)
                        vlv_search = ("vlv:1:%d:%d:%d:0" %
                                      (before, after, offset))
                        res = self.ldb.search(
                            self.base_dn,
                            expression=expression,
                            scope=ldb.SCOPE_SUBTREE,
                            attrs=[attr],
                            controls=[sort_control, show_deleted_control,
                                      vlv_search])
                        values = [msg[attr][0] for msg in res]
                        self.assertCorrectResults(values, expected_order,
                                                  offset, before, after)

    def test_server_vlv_no_cookie_show_deleted_only(self):
        """Check cookie-less VLV windows with the show_deleted control
        when only deleted objects are being looked at.

        The searches are one-level searches directly under the
        ``CN=Deleted Objects`` container.
        """
        attrs = [
            'objectGUID',
            'cn',
            'sAMAccountName',
            'objectSid',
            'whenChanged',
        ]

        # add some deleted users first, just in case there are none
        self.add_deleted_users(4)
        base = 'CN=Deleted Objects,%s' % self.base_dn
        expression = "(cn=vlv-deleted*)"
        show_deleted_control = "show_deleted:1"
        for attr in attrs:
            expected_order = self.get_expected_order_showing_deleted(
                attr,
                expression=expression,
                base=base,
                scope=ldb.SCOPE_ONELEVEL)
            total = len(expected_order)
            print("searching for attr %s amongst %d deleted objects" %
                  (attr, total))
            sort_control = "server_sort:1:0:%s" % attr
            # Sample roughly ten positions instead of every one.
            step = max(total // 10, 1)
            for before in [3, 0]:
                for after in [0, 2]:
                    for offset in range(1 + before, total - after, step):
                        vlv_search = ("vlv:1:%d:%d:%d:0" %
                                      (before, after, offset))
                        res = self.ldb.search(
                            base,
                            expression=expression,
                            scope=ldb.SCOPE_ONELEVEL,
                            attrs=[attr],
                            controls=[sort_control, show_deleted_control,
                                      vlv_search])
                        values = [msg[attr][0] for msg in res]
                        self.assertCorrectResults(values, expected_order,
                                                  offset, before, after)

    def test_server_vlv_with_cookie_show_deleted(self):
        """Check VLV windows with show_deleted while chaining VLV cookies.

        For each sort attribute the first search carries a plain VLV
        control; every later search for that attribute sends the cookie
        taken from the previous response. Offsets come from a seeded
        random generator, so the sequence is repeatable.
        """
        attrs = [
            'objectGUID', 'cn', 'sAMAccountName', 'objectSid', 'name',
            'whenChanged', 'usnChanged'
        ]
        self.add_deleted_users(6)
        random.seed(23)
        # these do not depend on the attribute, so build them once
        show_deleted_control = "show_deleted:1"
        expression = "(|(cn=vlvtest*)(cn=vlv-deleted*))"
        for attr in attrs:
            sort_control = "server_sort:1:0:%s" % attr
            expected_order = self.get_expected_order_showing_deleted(attr)
            n = len(expected_order)
            # no response yet for this attribute, so no cookie to send
            res = None
            for before in [3, 2, 1, 0]:
                after = before
                for _ in range(20):
                    offset = random.randrange(max(1, before - 2),
                                              min(n - after + 2, n))
                    if res is None:
                        vlv_search = "vlv:1:%d:%d:%d:0" % (before, after,
                                                           offset)
                    else:
                        cookie = get_cookie(res.controls, n)
                        vlv_search = ("vlv:1:%d:%d:%d:%s:%s" %
                                      (before, after, offset, n, cookie))

                    res = self.ldb.search(self.base_dn,
                                          expression=expression,
                                          scope=ldb.SCOPE_SUBTREE,
                                          attrs=[attr],
                                          controls=[
                                              sort_control, vlv_search,
                                              show_deleted_control
                                          ])

                    results = [x[attr][0] for x in res]

                    self.assertCorrectResults(results, expected_order, offset,
                                              before, after)

    def test_server_vlv_gte_with_cookie(self):
        """Check greater-than-or-equal VLV searches that chain cookies.

        For every user attribute except dn and objectclass, each gte
        value is searched with a range of before/after window sizes.
        After the first search for an attribute, the cookie from the
        previous response is passed along with the next request.
        """
        attrs = [
            x for x in self.users[0].keys() if x not in ('dn', 'objectclass')
        ]
        for attr in attrs:
            (gte_order, expected_order,
             gte_map) = self.get_gte_tests_and_order(attr)
            # In case there is some order dependency, disorder tests
            gte_tests = gte_order[:]
            random.seed(1)
            random.shuffle(gte_tests)
            res = None
            sort_control = "server_sort:1:0:%s" % attr
            for before in [0, 1, 2, 4]:
                for after in [0, 1, 3, 6]:
                    for gte in gte_tests:
                        if res is not None:
                            cookie = get_cookie(res.controls, len(self.users))
                        else:
                            cookie = None
                        vlv_search = encode_vlv_control(before=before,
                                                        after=after,
                                                        gte=gte,
                                                        cookie=cookie)

                        res = self.ldb.search(
                            self.ou,
                            scope=ldb.SCOPE_ONELEVEL,
                            attrs=[attr],
                            controls=[sort_control, vlv_search])

                        results = [x[attr][0] for x in res]
                        # a gte value with no entry in the map is treated
                        # as pointing just past the end of the list
                        offset = gte_map.get(gte, len(expected_order))

                        # here offset is 0-based
                        start = max(offset - before, 0)
                        end = offset + 1 + after

                        expected_results = expected_order[start:end]

                        # assertEquals is a deprecated alias of assertEqual
                        self.assertEqual(expected_results, results)

    def test_server_vlv_gte_no_cookie(self):
        """Check greater-than-or-equal VLV searches without any cookie.

        For every user attribute except dn and objectclass, each gte
        value is searched with several before/after window sizes and the
        returned window is compared with the matching slice of the
        expected ordering. Diagnostics are printed before the assertion
        when the two differ.
        """
        attrs = [
            x for x in self.users[0].keys() if x not in ('dn', 'objectclass')
        ]
        for attr in attrs:
            (gte_order, expected_order,
             gte_map) = self.get_gte_tests_and_order(attr)
            # In case there is some order dependency, disorder tests
            gte_tests = gte_order[:]
            random.seed(1)
            random.shuffle(gte_tests)

            sort_control = "server_sort:1:0:%s" % attr
            for before in [0, 1, 3]:
                for after in [0, 4]:
                    for gte in gte_tests:
                        vlv_search = encode_vlv_control(before=before,
                                                        after=after,
                                                        gte=gte)

                        res = self.ldb.search(
                            self.ou,
                            scope=ldb.SCOPE_ONELEVEL,
                            attrs=[attr],
                            controls=[sort_control, vlv_search])
                        results = [x[attr][0] for x in res]

                        # here offset is 0-based; a gte value with no entry
                        # in the map is treated as just past the last item
                        offset = gte_map.get(gte, len(expected_order))
                        start = max(offset - before, 0)
                        end = offset + after + 1
                        expected_results = expected_order[start:end]
                        if expected_results != results:
                            middle = expected_order[len(expected_order) // 2]
                            # Single-argument print() calls behave the same
                            # with or without print_function imported.
                            print("%s %s" % (expected_results, results))
                            print(middle)
                            print(expected_order)
                            print("")
                            print(
                                "\nattr %s offset %d before %d "
                                "after %d gte %s" %
                                (attr, offset, before, after, gte))
                        self.assertEqual(expected_results, results)

    def test_multiple_searches(self):
        """Check cookie expiry and per-connection scoping of VLV searches.

        The maximum number of concurrent vlv searches per connection is
        currently set at 3, so once more searches than that have been
        opened the cookie from the first one should fail, while the most
        recent cookie should still be usable on the same connection.
        """
        # Windows has a limit of 10 VLVs where there are low numbers
        # of objects in each search.
        candidate_attrs = [
            key for key in self.users[0].keys()
            if key not in ('dn', 'objectclass')
        ]
        # the list is doubled so that up to 12 searches can be issued
        attrs = (candidate_attrs * 2)[:12]

        vlv_cookies = []
        for attr in attrs:
            sort_control = "server_sort:1:0:%s" % attr

            res = self.ldb.search(self.ou,
                                  scope=ldb.SCOPE_ONELEVEL,
                                  attrs=[attr],
                                  controls=[sort_control, "vlv:1:1:1:1:0"])

            vlv_cookies.append(get_cookie(res.controls, len(self.users)))
            time.sleep(0.2)

        # From here on attr and sort_control are those of the last search.

        # the oldest cookie should fail
        self.assertRaises(
            ldb.LdbError,
            self.ldb.search,
            self.ou,
            scope=ldb.SCOPE_ONELEVEL,
            attrs=[attr],
            controls=[sort_control,
                      "vlv:1:1:1:1:0:%s" % vlv_cookies[0]])

        # and the newest cookie should succeed
        self.ldb.search(
            self.ou,
            scope=ldb.SCOPE_ONELEVEL,
            attrs=[attr],
            controls=[sort_control,
                      "vlv:1:1:1:1:0:%s" % vlv_cookies[-1]])

        # a new connection doesn't share cookies, so the same cookie
        # should fail there
        other_connection = SamDB(host,
                                 credentials=creds,
                                 session_info=system_session(lp),
                                 lp=lp)

        self.assertRaises(
            ldb.LdbError,
            other_connection.search,
            self.ou,
            scope=ldb.SCOPE_ONELEVEL,
            attrs=[attr],
            controls=[sort_control,
                      "vlv:1:1:1:1:0:%s" % vlv_cookies[-1]])

        # but now without the critical flag it just does no VLV.
        other_connection.search(
            self.ou,
            scope=ldb.SCOPE_ONELEVEL,
            attrs=[attr],
            controls=[sort_control,
                      "vlv:0:1:1:1:0:%s" % vlv_cookies[-1]])
Exemple #31
0
class DsdbTests(TestCase):
    def setUp(self):
        """Open a SamDB connection for the test environment.

        The loadparm context, guessed credentials and system session are
        kept on the instance because the locking tests re-open the
        database with them after forking.
        """
        super(DsdbTests, self).setUp()
        lp = samba.tests.env_loadparm()
        creds = Credentials()
        creds.guess(lp)
        session = system_session()

        self.lp = lp
        self.creds = creds
        self.session = session
        self.samdb = SamDB(session_info=session, credentials=creds, lp=lp)

    def test_get_oid_from_attrid(self):
        """get_oid_from_attid() returns the dotted OID for attid 591614."""
        oid = self.samdb.get_oid_from_attid(591614)
        # assertEquals is a deprecated alias of assertEqual
        self.assertEqual(oid, "1.2.840.113556.1.4.1790")

    def test_error_replpropertymetadata(self):
        """A replPropertyMetaData write that only bumps a version must fail.

        The version recorded for attid 13 (description) on the
        Administrator entry is incremented and the blob is written back;
        the modify is expected to raise LdbError.
        """
        res = self.samdb.search(expression="cn=Administrator",
                                scope=ldb.SCOPE_SUBTREE,
                                attrs=["replPropertyMetaData"])
        repl = ndr_unpack(drsblobs.replPropertyMetaDataBlob,
                          str(res[0]["replPropertyMetaData"]))
        for o in repl.ctr.array:
            # Search for Description
            if o.attid == 13:
                o.version = o.version + 1
        replBlob = ndr_pack(repl)
        msg = ldb.Message()
        msg.dn = res[0].dn
        msg["replPropertyMetaData"] = ldb.MessageElement(
            replBlob, ldb.FLAG_MOD_REPLACE, "replPropertyMetaData")
        # NOTE(review): this OID is presumably the control that permits
        # writing replPropertyMetaData directly; confirm against dsdb.
        self.assertRaises(ldb.LdbError, self.samdb.modify, msg,
                          ["local_oid:1.3.6.1.4.1.7165.4.3.14:0"])

    def test_error_replpropertymetadata_nochange(self):
        """Writing back an unmodified replPropertyMetaData must fail.

        The blob read from the Administrator entry is unpacked, repacked
        without any change and sent in a replace; the modify is expected
        to raise LdbError.
        """
        res = self.samdb.search(expression="cn=Administrator",
                                scope=ldb.SCOPE_SUBTREE,
                                attrs=["replPropertyMetaData"])
        entry = res[0]
        unpacked = ndr_unpack(drsblobs.replPropertyMetaDataBlob,
                              str(entry["replPropertyMetaData"]))
        repacked = ndr_pack(unpacked)

        msg = ldb.Message()
        msg.dn = entry.dn
        msg["replPropertyMetaData"] = ldb.MessageElement(
            repacked, ldb.FLAG_MOD_REPLACE, "replPropertyMetaData")

        controls = ["local_oid:1.3.6.1.4.1.7165.4.3.14:0"]
        self.assertRaises(ldb.LdbError, self.samdb.modify, msg, controls)

    def test_error_replpropertymetadata_allow_sort(self):
        """An unmodified replPropertyMetaData write passes with both controls.

        The blob from the Administrator entry is repacked unchanged and
        written back with two local_oid controls; no exception is
        expected.
        """
        res = self.samdb.search(expression="cn=Administrator",
                                scope=ldb.SCOPE_SUBTREE,
                                attrs=["replPropertyMetaData"])
        entry = res[0]
        unpacked = ndr_unpack(drsblobs.replPropertyMetaDataBlob,
                              str(entry["replPropertyMetaData"]))
        repacked = ndr_pack(unpacked)

        msg = ldb.Message()
        msg.dn = entry.dn
        msg["replPropertyMetaData"] = ldb.MessageElement(
            repacked, ldb.FLAG_MOD_REPLACE, "replPropertyMetaData")

        # NOTE(review): going by the test name, the second control is
        # presumably what allows the blob to be re-sorted; confirm.
        controls = [
            "local_oid:1.3.6.1.4.1.7165.4.3.14:0",
            "local_oid:1.3.6.1.4.1.7165.4.3.25:0"
        ]
        self.samdb.modify(msg, controls)

    def test_twoatt_replpropertymetadata(self):
        """Changing replPropertyMetaData together with another attribute fails.

        The description metadata (attid 13) of the Administrator entry
        gets a bumped version and local USN, and the same message also
        replaces the description value; the modify is expected to raise
        LdbError.
        """
        res = self.samdb.search(expression="cn=Administrator",
                                scope=ldb.SCOPE_SUBTREE,
                                attrs=["replPropertyMetaData", "uSNChanged"])
        repl = ndr_unpack(drsblobs.replPropertyMetaDataBlob,
                          str(res[0]["replPropertyMetaData"]))
        for o in repl.ctr.array:
            # Search for Description
            if o.attid == 13:
                o.version = o.version + 1
                # int() promotes to long on Python 2 when needed and,
                # unlike long(), also exists on Python 3
                o.local_usn = int(str(res[0]["uSNChanged"])) + 1
        replBlob = ndr_pack(repl)
        msg = ldb.Message()
        msg.dn = res[0].dn
        msg["replPropertyMetaData"] = ldb.MessageElement(
            replBlob, ldb.FLAG_MOD_REPLACE, "replPropertyMetaData")
        msg["description"] = ldb.MessageElement("new val",
                                                ldb.FLAG_MOD_REPLACE,
                                                "description")
        self.assertRaises(ldb.LdbError, self.samdb.modify, msg,
                          ["local_oid:1.3.6.1.4.1.7165.4.3.14:0"])

    def test_set_replpropertymetadata(self):
        """A consistent replPropertyMetaData update is accepted.

        For the description metadata (attid 13) of the Administrator
        entry the version is incremented and both the local and the
        originating USN are set to uSNChanged + 1; the modify is expected
        to succeed.
        """
        res = self.samdb.search(expression="cn=Administrator",
                                scope=ldb.SCOPE_SUBTREE,
                                attrs=["replPropertyMetaData", "uSNChanged"])
        repl = ndr_unpack(drsblobs.replPropertyMetaDataBlob,
                          str(res[0]["replPropertyMetaData"]))
        # int() promotes to long on Python 2 when needed and, unlike
        # long(), also exists on Python 3
        next_usn = int(str(res[0]["uSNChanged"])) + 1
        for o in repl.ctr.array:
            # Search for Description
            if o.attid == 13:
                o.version = o.version + 1
                o.local_usn = next_usn
                o.originating_usn = next_usn
        replBlob = ndr_pack(repl)
        msg = ldb.Message()
        msg.dn = res[0].dn
        msg["replPropertyMetaData"] = ldb.MessageElement(
            replBlob, ldb.FLAG_MOD_REPLACE, "replPropertyMetaData")
        self.samdb.modify(msg, ["local_oid:1.3.6.1.4.1.7165.4.3.14:0"])

    def test_ok_get_attribute_from_attid(self):
        """get_attribute_from_attid() maps attid 13 to "description"."""
        self.assertEqual(self.samdb.get_attribute_from_attid(13),
                         "description")

    def test_ko_get_attribute_from_attid(self):
        """get_attribute_from_attid() returns None for attid 11979."""
        self.assertIsNone(self.samdb.get_attribute_from_attid(11979))

    def test_get_attribute_replmetadata_version(self):
        """The metadata version of Administrator's unicodePwd reads as 1.

        NOTE(review): this presumably relies on the test environment
        never having changed that password; confirm.
        """
        res = self.samdb.search(expression="cn=Administrator",
                                scope=ldb.SCOPE_SUBTREE,
                                attrs=["dn"])
        self.assertEqual(len(res), 1)
        dn = str(res[0].dn)
        self.assertEqual(
            self.samdb.get_attribute_replmetadata_version(dn, "unicodePwd"), 1)

    def test_set_attribute_replmetadata_version(self):
        """A metadata version set on description can be read back.

        The current version for Administrator's description is read,
        raised by 2 through set_attribute_replmetadata_version(), and the
        new value is then expected from the getter.
        """
        res = self.samdb.search(expression="cn=Administrator",
                                scope=ldb.SCOPE_SUBTREE,
                                attrs=["dn"])
        self.assertEqual(len(res), 1)
        dn = str(res[0].dn)
        version = self.samdb.get_attribute_replmetadata_version(
            dn, "description")
        self.samdb.set_attribute_replmetadata_version(dn, "description",
                                                      version + 2)
        self.assertEqual(
            self.samdb.get_attribute_replmetadata_version(dn, "description"),
            version + 2)

    def test_db_lock1(self):
        """Check that a full search waits for another process's write lock.

        A forked child re-opens the database, prepares a transaction
        commit, reports "prepared" over a pipe, sleeps 2 seconds and then
        cancels. The parent times a complete search_iterator() pass and
        expects it to take more than 1.9 seconds, then checks that the
        child exited with status 0.
        """
        basedn = self.samdb.get_default_basedn()
        # pipe used by the child to tell the parent it holds the lock
        (r1, w1) = os.pipe()

        pid = os.fork()
        if pid == 0:
            # In the child, close the main DB, re-open just one DB
            del (self.samdb)
            gc.collect()
            self.samdb = SamDB(session_info=self.session,
                               credentials=self.creds,
                               lp=self.lp)

            self.samdb.transaction_start()

            # add and delete a user so the transaction has something in it
            dn = "cn=test_db_lock_user,cn=users," + str(basedn)
            self.samdb.add({
                "dn": dn,
                "objectclass": "user",
            })
            self.samdb.delete(dn)

            # Obtain a write lock
            self.samdb.transaction_prepare_commit()
            os.write(w1, b"prepared")
            time.sleep(2)

            # Drop the write lock
            self.samdb.transaction_cancel()
            # leave the child without returning into the test runner
            os._exit(0)

        # parent: wait until the child reports that it holds the lock
        self.assertEqual(os.read(r1, 8), b"prepared")

        start = time.time()

        # We need to hold this iterator open to hold the all-record lock.
        res = self.samdb.search_iterator()

        # This should take at least 2 seconds because the transaction
        # has a write lock on one backend db open

        # Release the locks
        for l in res:
            pass

        # the timed interval covers opening and draining the iterator
        end = time.time()
        self.assertGreater(end - start, 1.9)

        # reap the child and make sure it exited cleanly
        (got_pid, status) = os.waitpid(pid, 0)
        self.assertEqual(got_pid, pid)
        self.assertTrue(os.WIFEXITED(status))
        self.assertEqual(os.WEXITSTATUS(status), 0)

    def test_db_lock2(self):
        """Check that prepare_commit() waits for a reader in another process.

        A forked child re-opens the database and keeps a search iterator
        open. Parent and child step through a handshake over two pipes
        ("start"/"started", "add"/"added", "prepare"/"prepared"); the
        child keeps its iterator open for 2 seconds after "prepare", and
        the parent expects transaction_prepare_commit() to take more than
        1.9 seconds. The child exits with 1, 2 or 3 if a handshake step
        does not match, and 0 otherwise.
        """
        basedn = self.samdb.get_default_basedn()
        # r1/w1 carry parent -> child messages, r2/w2 child -> parent
        (r1, w1) = os.pipe()
        (r2, w2) = os.pipe()

        pid = os.fork()
        if pid == 0:
            # In the child, close the main DB, re-open
            del (self.samdb)
            gc.collect()
            self.samdb = SamDB(session_info=self.session,
                               credentials=self.creds,
                               lp=self.lp)

            # We need to hold this iterator open to hold the all-record lock.
            res = self.samdb.search_iterator()

            os.write(w2, b"start")
            if (os.read(r1, 7) != b"started"):
                os._exit(1)

            os.write(w2, b"add")
            if (os.read(r1, 5) != b"added"):
                os._exit(2)

            # Keep the iterator open for 2 seconds so that
            # prepare_commit() in the parent is blocked meanwhile.
            os.write(w2, b"prepare")
            time.sleep(2)

            # Release the locks
            for _ in res:
                pass

            if (os.read(r1, 8) != b"prepared"):
                os._exit(3)

            os._exit(0)

        # We can start the transaction during the search
        # because both just grab the all-record read lock.
        self.assertEqual(os.read(r2, 5), b"start")
        self.samdb.transaction_start()
        os.write(w1, b"started")

        self.assertEqual(os.read(r2, 3), b"add")
        dn = "cn=test_db_lock_user,cn=users," + str(basedn)
        self.samdb.add({
            "dn": dn,
            "objectclass": "user",
        })
        self.samdb.delete(dn)
        os.write(w1, b"added")

        # Obtain a write lock, this will block until
        # the child releases the read lock.
        self.assertEqual(os.read(r2, 7), b"prepare")
        start = time.time()
        self.samdb.transaction_prepare_commit()
        end = time.time()
        try:
            self.assertGreater(end - start, 1.9)
        finally:
            # Always unblock and reap the child, even when the timing
            # check failed, so it is not left waiting on the pipe.
            os.write(w1, b"prepared")

            # Drop the write lock
            self.samdb.transaction_cancel()

            (got_pid, status) = os.waitpid(pid, 0)

        # Checked outside the finally block so that a failure here can
        # never replace a timing failure raised above.
        self.assertEqual(got_pid, pid)
        self.assertTrue(os.WIFEXITED(status))
        self.assertEqual(os.WEXITSTATUS(status), 0)

    def test_db_lock3(self):
        """Check prepare_commit() blocking for a write to "@DSDB_LOCK_TEST".

        Same handshake as the user-object variant of this test: a forked
        child keeps a search iterator open for 2 seconds after sending
        "prepare", and the parent expects transaction_prepare_commit() to
        take more than 1.9 seconds. Here the parent's transaction adds
        and deletes the special "@DSDB_LOCK_TEST" record instead of a
        user object.
        """
        basedn = self.samdb.get_default_basedn()
        # r1/w1 carry parent -> child messages, r2/w2 child -> parent
        (r1, w1) = os.pipe()
        (r2, w2) = os.pipe()

        pid = os.fork()
        if pid == 0:
            # In the child, close the main DB, re-open
            del (self.samdb)
            gc.collect()
            self.samdb = SamDB(session_info=self.session,
                               credentials=self.creds,
                               lp=self.lp)

            # We need to hold this iterator open to hold the all-record lock.
            res = self.samdb.search_iterator()

            os.write(w2, b"start")
            if (os.read(r1, 7) != b"started"):
                os._exit(1)

            os.write(w2, b"add")
            if (os.read(r1, 5) != b"added"):
                os._exit(2)

            # Keep the iterator open for 2 seconds so that
            # prepare_commit() in the parent is blocked meanwhile.
            os.write(w2, b"prepare")
            time.sleep(2)

            # Release the locks
            for _ in res:
                pass

            if (os.read(r1, 8) != b"prepared"):
                os._exit(3)

            os._exit(0)

        # We can start the transaction during the search
        # because both just grab the all-record read lock.
        self.assertEqual(os.read(r2, 5), b"start")
        self.samdb.transaction_start()
        os.write(w1, b"started")

        self.assertEqual(os.read(r2, 3), b"add")

        # This will end up in the top level db
        dn = "@DSDB_LOCK_TEST"
        self.samdb.add({"dn": dn})
        self.samdb.delete(dn)
        os.write(w1, b"added")

        # Obtain a write lock, this will block until
        # the child releases the read lock.
        self.assertEqual(os.read(r2, 7), b"prepare")
        start = time.time()
        self.samdb.transaction_prepare_commit()
        end = time.time()
        try:
            self.assertGreater(end - start, 1.9)
        finally:
            # Always unblock and reap the child, even when the timing
            # check failed; otherwise it would stay blocked reading r1.
            os.write(w1, b"prepared")

            # Drop the write lock
            self.samdb.transaction_cancel()

            (got_pid, status) = os.waitpid(pid, 0)

        self.assertTrue(os.WIFEXITED(status))
        self.assertEqual(os.WEXITSTATUS(status), 0)
        self.assertEqual(got_pid, pid)

    def _test_full_db_lock1(self, backend_path):
        """Check that opening a full search waits for a backend write lock.

        A forked child opens only the ldb file at backend_path, prepares
        a transaction commit on it, reports "prepared" over a pipe,
        sleeps 2 seconds and cancels. The parent expects its call to
        search_iterator() on the main database to take more than 1.9
        seconds, then checks that the child exited with status 0.

        :param backend_path: path of the backend ldb file the child locks
        """
        # pipe used by the child to tell the parent it holds the lock
        (r1, w1) = os.pipe()

        pid = os.fork()
        if pid == 0:
            # In the child, close the main DB, re-open just one DB
            del (self.samdb)
            gc.collect()

            backenddb = ldb.Ldb(backend_path)

            backenddb.transaction_start()

            # add and delete a record so the transaction has something in it
            backenddb.add({"dn": "@DSDB_LOCK_TEST"})
            backenddb.delete("@DSDB_LOCK_TEST")

            # Obtain a write lock
            backenddb.transaction_prepare_commit()
            os.write(w1, b"prepared")
            time.sleep(2)

            # Drop the write lock
            backenddb.transaction_cancel()
            # leave the child without returning into the test runner
            os._exit(0)

        # parent: wait until the child reports that it holds the lock
        self.assertEqual(os.read(r1, 8), b"prepared")

        start = time.time()

        # We need to hold this iterator open to hold the all-record lock.
        res = self.samdb.search_iterator()

        # This should take at least 2 seconds because the transaction
        # has a write lock on one backend db open

        # only the search_iterator() call itself is timed here; the
        # results are drained afterwards
        end = time.time()
        self.assertGreater(end - start, 1.9)

        # Release the locks
        for l in res:
            pass

        # reap the child and make sure it exited cleanly
        (got_pid, status) = os.waitpid(pid, 0)
        self.assertEqual(got_pid, pid)
        self.assertTrue(os.WIFEXITED(status))
        self.assertEqual(os.WEXITSTATUS(status), 0)

    def test_full_db_lock1(self):
        """Run the full-db lock1 check against the default base DN backend.

        The backend file is "<casefolded base DN>.ldb" inside "sam.ldb.d"
        under the private directory.
        """
        casefolded_dn = self.samdb.get_default_basedn().get_casefold()
        partition_file = os.path.join("sam.ldb.d", "%s.ldb" % casefolded_dn)
        self._test_full_db_lock1(self.lp.private_path(partition_file))

    def test_full_db_lock1_config(self):
        """Run the full-db lock1 check against the config base DN backend.

        The backend file is "<casefolded config DN>.ldb" inside
        "sam.ldb.d" under the private directory.
        """
        casefolded_dn = self.samdb.get_config_basedn().get_casefold()
        partition_file = os.path.join("sam.ldb.d", "%s.ldb" % casefolded_dn)
        self._test_full_db_lock1(self.lp.private_path(partition_file))

    def _test_full_db_lock2(self, backend_path):
        """Check that a backend prepare_commit() waits for a full-db reader.

        A forked child re-opens the whole database and keeps a search
        iterator open. The parent closes its own SamDB, opens only the
        ldb file at backend_path and walks through a handshake with the
        child over two pipes; the child keeps its iterator open for 2
        seconds after "prepare", and the parent expects
        transaction_prepare_commit() on the backend to take more than
        1.9 seconds. The child exits with 1, 2 or 3 if a handshake step
        does not match, and 0 otherwise.

        :param backend_path: path of the backend ldb file the parent opens
        """
        # r1/w1 carry parent -> child messages, r2/w2 child -> parent
        (r1, w1) = os.pipe()
        (r2, w2) = os.pipe()

        pid = os.fork()
        if pid == 0:

            # In the child, close the main DB, re-open
            del (self.samdb)
            gc.collect()
            self.samdb = SamDB(session_info=self.session,
                               credentials=self.creds,
                               lp=self.lp)

            # We need to hold this iterator open to hold the all-record lock.
            res = self.samdb.search_iterator()

            os.write(w2, b"start")
            if (os.read(r1, 7) != b"started"):
                os._exit(1)
            os.write(w2, b"add")
            if (os.read(r1, 5) != b"added"):
                os._exit(2)

            # Keep the iterator open for 2 seconds so that
            # prepare_commit() in the parent is blocked meanwhile.
            os.write(w2, b"prepare")
            time.sleep(2)

            # Release the locks
            for _ in res:
                pass

            if (os.read(r1, 8) != b"prepared"):
                os._exit(3)

            os._exit(0)

        # In the parent, close the main DB, re-open just one DB
        del (self.samdb)
        gc.collect()
        backenddb = ldb.Ldb(backend_path)

        # We can start the transaction during the search
        # because both just grab the all-record read lock.
        self.assertEqual(os.read(r2, 5), b"start")
        backenddb.transaction_start()
        os.write(w1, b"started")

        self.assertEqual(os.read(r2, 3), b"add")
        backenddb.add({"dn": "@DSDB_LOCK_TEST"})
        backenddb.delete("@DSDB_LOCK_TEST")
        os.write(w1, b"added")

        # Obtain a write lock, this will block until
        # the child releases the read lock.
        self.assertEqual(os.read(r2, 7), b"prepare")
        start = time.time()
        backenddb.transaction_prepare_commit()
        end = time.time()

        try:
            self.assertGreater(end - start, 1.9)
        finally:
            # Always unblock and reap the child, even when the timing
            # check failed, so it is not left waiting on the pipe.
            os.write(w1, b"prepared")

            # Drop the write lock
            backenddb.transaction_cancel()

            (got_pid, status) = os.waitpid(pid, 0)

        # Checked outside the finally block so that a failure here can
        # never replace a timing failure raised above.
        self.assertEqual(got_pid, pid)
        self.assertTrue(os.WIFEXITED(status))
        self.assertEqual(os.WEXITSTATUS(status), 0)

    def test_full_db_lock2(self):
        """Run the full-db lock2 check against the default base DN backend.

        The backend file is "<casefolded base DN>.ldb" inside "sam.ldb.d"
        under the private directory.
        """
        casefolded_dn = self.samdb.get_default_basedn().get_casefold()
        partition_file = os.path.join("sam.ldb.d", "%s.ldb" % casefolded_dn)
        self._test_full_db_lock2(self.lp.private_path(partition_file))

    def test_full_db_lock2_config(self):
        """Run the full-db lock2 check against the config base DN backend.

        The backend file is "<casefolded config DN>.ldb" inside
        "sam.ldb.d" under the private directory.
        """
        casefolded_dn = self.samdb.get_config_basedn().get_casefold()
        partition_file = os.path.join("sam.ldb.d", "%s.ldb" % casefolded_dn)
        self._test_full_db_lock2(self.lp.private_path(partition_file))

    def test_no_error_on_invalid_control(self):
        """An unimplemented control sent with flag 0 must not fail a search.

        The test fails if the search raises ldb.LdbError.
        """
        invalid_control = ("local_oid:%s:0" %
                           dsdb.DSDB_CONTROL_INVALID_NOT_IMPLEMENTED)
        try:
            self.samdb.search(expression="cn=Administrator",
                              scope=ldb.SCOPE_SUBTREE,
                              attrs=["replPropertyMetaData"],
                              controls=[invalid_control])
        except ldb.LdbError:
            self.fail("Should have not raised an exception")

    def test_error_on_invalid_critical_control(self):
        """An unimplemented control sent with flag 1 fails in one way only.

        If the search raises ldb.LdbError, its error number must be
        ERR_UNSUPPORTED_CRITICAL_EXTENSION.

        NOTE(review): the test also passes when no exception is raised
        at all; confirm whether that is intended.
        """
        try:
            self.samdb.search(
                expression="cn=Administrator",
                scope=ldb.SCOPE_SUBTREE,
                attrs=["replPropertyMetaData"],
                controls=[
                    "local_oid:%s:1" %
                    dsdb.DSDB_CONTROL_INVALID_NOT_IMPLEMENTED
                ])
        except ldb.LdbError as e:
            # Indexing the exception object itself only works on
            # Python 2; e.args gives the same values on both versions.
            if e.args[0] != ldb.ERR_UNSUPPORTED_CRITICAL_EXTENSION:
                self.fail(
                    "Got %s should have got ERR_UNSUPPORTED_CRITICAL_EXTENSION"
                    % e.args[1])
Exemple #32
0
def deprovision_schema(setup_path, names, lp, creds, reporter, ldif, msg, modify_mode=False):
    """Deprovision/unmodify schema using LDIF specified file, by reverting the
    modifications contained therein.

    In modify mode every "add:" operation in the file is turned into a
    "delete:" operation, "replace:" sections are dropped, and the
    resulting entries are applied last to first on a best-effort basis.
    Otherwise every "dn:" line names an entry that is deleted, again
    last to first. All changes happen in one transaction, which is
    cancelled if an error propagates and committed otherwise.

    :param setup_path: Path to the setup directory.
    :param names: provision names object.
    :param lp: Loadparm context
    :param creds: Credentials Context
    :param reporter: A progress reporter instance (subclass of AbstractProgressReporter)
    :param ldif: path to the LDIF file
    :param msg: reporter message
    :param modify_mode: whether entries are added or modified
    """

    session_info = system_session()
    db = SamDB(url=get_ldb_url(lp, creds, names), session_info=session_info,
               credentials=creds, lp=lp)

    db.transaction_start()

    try:
        reporter.reportNextStep(msg)

        ldif_content = read_and_sub_file(setup_path(ldif),
                                         {"FIRSTORG": names.firstorg,
                                          "FIRSTORGDN": names.firstorgdn,
                                          "CONFIGDN": names.configdn,
                                          "SCHEMADN": names.schemadn,
                                          "DOMAINDN": names.domaindn,
                                          "DOMAIN": names.domain,
                                          "DNSDOMAIN": names.dnsdomain,
                                          "NETBIOSNAME": names.netbiosname,
                                          "HOSTNAME": names.hostname
                                          })
        lines = ldif_content.splitlines()
        if modify_mode:
            # keep_line says whether the section being read is kept
            keep_line = False
            current_entry = []
            entries = [current_entry]
            for line in lines:
                skip_this_line = False
                if line.startswith("dn:") or line == "":
                    # a dn line or an empty line starts a new entry
                    current_entry = []
                    entries.append(current_entry)
                    keep_line = True
                elif line.startswith("add:"):
                    # revert an addition by deleting the same attribute
                    keep_line = True
                    line = "delete:" + line[4:]
                elif line.startswith("replace:"):
                    # a replace cannot be reverted from the file alone,
                    # so the whole section is dropped
                    keep_line = False
                elif line.startswith("#") or line.strip() == "":
                    skip_this_line = True

                if keep_line and not skip_this_line:
                    current_entry.append(line)

            # undo the modifications in the opposite order
            for entry in reversed(entries):
                entry_ldif = "\n".join(entry)
                print(entry_ldif)
                try:
                    db.modify_ldif(entry_ldif)
                except Exception as e:
                    # Best effort: one entry that cannot be reverted must
                    # not stop the rest, but the failure is made visible.
                    print("Ignoring failure to revert entry: %s" % (e,))
        else:
            for line in reversed(lines):
                if line.startswith("dn:"):
                    # strip() copes with any amount of whitespace after
                    # the "dn:" prefix
                    db.delete(line[3:].strip())

    except BaseException:
        # roll back on any failure, including interrupts, then re-raise
        db.transaction_cancel()
        raise

    db.transaction_commit()
def purge_computer_with_DC_objects(ucr, binddn, bindpw, computername):
    """Interactively remove a Samba 4 computer account and its related objects.

    The function looks up the computer account in the local ``sam.ldb``,
    asks the operator for two confirmations, removes the DNS records via
    ``purge_s4_dns_records``, deletes the objects referenced by the
    account's backlink attributes, deletes the account subtree leaf-first
    and, after a further confirmation, calls ``purge_udm_computer``.

    :param ucr: mapping that provides at least ``"samba4/ldap/base"``
    :param binddn: bind DN, passed through to the DNS/UDM purge helpers
    :param bindpw: bind password, passed through to the DNS/UDM purge helpers
    :param computername: account name without the trailing ``$``

    The process exits with status 1 if the account or the domain object
    cannot be found or the ``YES`` confirmation fails, and with status 2
    if one of the ``[y/N]`` prompts is declined.  Note that declining the
    final UDM prompt happens *after* the SAM objects were already removed.
    """

    lp = LoadParm()
    lp.load('/etc/samba/smb.conf')

    samdb = SamDB(os.path.join(SAMBA_PRIVATE_DIR, "sam.ldb"),
                  session_info=system_session(lp),
                  lp=lp)

    backlink_attribute_list = [
        "serverReferenceBL", "frsComputerReferenceBL",
        "msDFSR-ComputerReferenceBL"
    ]
    msgs = samdb.search(
        base=ucr["samba4/ldap/base"],
        scope=samba.ldb.SCOPE_SUBTREE,
        expression="(&(objectClass=computer)(sAMAccountName=%s$))" %
        computername,
        attrs=backlink_attribute_list)
    if not msgs:
        print("Samba 4 computer account '%s' not found." % (computername, ))
        sys.exit(1)

    answer = raw_input("Really remove %s from Samba 4? [y/N]: " % computername)
    if answer.lower() not in ('y', 'yes'):
        print("Ok, stopping as requested.\n")
        sys.exit(2)

    computer_obj = msgs[0]

    # Second, case-sensitive confirmation before anything is modified.
    answer = raw_input("If you are really sure type YES and hit enter: ")
    if answer != 'YES':
        print("The answer was not 'YES', confirmation failed.\n")
        sys.exit(1)
    else:
        print("Ok, continuing as requested.\n")

    # Determine the NTDS_objectGUID (stays None when the account has no
    # serverReferenceBL or no "NTDS Settings" object is found below it).
    NTDS_objectGUID = None
    if "serverReferenceBL" in computer_obj:
        msgs = samdb.search(base=computer_obj["serverReferenceBL"][0],
                            scope=samba.ldb.SCOPE_SUBTREE,
                            expression="(CN=NTDS Settings)",
                            attrs=["objectGUID"])
        if msgs and "objectGUID" in msgs[0]:
            NTDS_objectGUID = str(
                ndr_unpack(misc.GUID, msgs[0]["objectGUID"][0]))

    # Determine the Domain_GUID
    msgs = samdb.search(base=ucr["samba4/ldap/base"],
                        scope=samba.ldb.SCOPE_BASE,
                        attrs=["objectGUID"])
    if not msgs:
        print("Samba 4 Domain_GUID for base dn '%s' not found." % (
            ucr["samba4/ldap/base"], ))
        sys.exit(1)
    Domain_GUID = str(ndr_unpack(misc.GUID, msgs[0]["objectGUID"][0]))

    # Build current site list
    msgs = samdb.search(base="CN=Configuration,%s" % ucr["samba4/ldap/base"],
                        scope=samba.ldb.SCOPE_SUBTREE,
                        expression="(objectClass=site)",
                        attrs=["cn"])
    site_list = [obj["cn"][0] for obj in msgs]

    # Remove Samba 4 DNS records
    purge_s4_dns_records(ucr, binddn, bindpw, computername, NTDS_objectGUID,
                         Domain_GUID, site_list)

    # Remove objects from Samba 4 SAM database.  Failures are reported and
    # the loop continues with the next object (best effort); only
    # Exception is caught so that Ctrl-C / SystemExit still abort the run.
    for backlink_attribute in backlink_attribute_list:
        if backlink_attribute in computer_obj:
            backlink_object = computer_obj[backlink_attribute][0]
            try:
                print("Removing %s from SAM database." % (backlink_object, ))
                samdb.delete(backlink_object, ["tree_delete:0"])
            except Exception:
                sys.stderr.write(
                    "Removal of Samba 4 %s objects %s from Samba 4 SAM database failed.\n" % (
                        backlink_attribute,
                        backlink_object,
                    ))
                # Keep the traceback on the same stream as the message.
                traceback.print_exc()

    # Now delete the Samba 4 computer account and sub-objects
    # Cannot use tree_delete on isCriticalSystemObject, perform recursive delete like ldbdel code does it:
    msgs = samdb.search(base=computer_obj.dn,
                        scope=samba.ldb.SCOPE_SUBTREE,
                        attrs=["dn"])
    obj_dn_list = [obj.dn for obj in msgs]
    # Longest DNs first so that children are removed before their parents.
    obj_dn_list.sort(key=len)
    obj_dn_list.reverse()
    for obj_dn in obj_dn_list:
        try:
            print("Removing %s from SAM database." % (obj_dn, ))
            samdb.delete(obj_dn)
        except Exception:
            sys.stderr.write(
                "Removal of Samba 4 computer account object %s from Samba 4 SAM database failed.\n" % (
                    obj_dn, ))
            traceback.print_exc()

    answer = raw_input("Really remove %s from UDM as well? [y/N]: " %
                       computername)
    if answer.lower() not in ('y', 'yes'):
        print("Ok, stopping as requested.\n")
        sys.exit(2)

    # Finally, for consistency remove S4 computer object from UDM
    purge_udm_computer(ucr, binddn, bindpw, computername)
Exemple #34
0
class AuditLogDsdbTests(AuditLogTestBase):
    """Checks of the DSDB audit messages produced by directory changes.

    Every test performs one change (password change, add, modify, delete,
    LSA secret handling) against the server named by the ``SERVER``
    environment variable and then inspects the ``dsdbChange`` /
    ``dsdbTransaction`` messages collected by the base class.

    ``assertRegexpMatches`` is kept as in the rest of the file; only the
    ``assertEquals`` alias was replaced by the canonical ``assertEqual``.
    """

    def setUp(self):
        """Connect to the server and (re)create the test user.

        Reads ``CLIENT_IP``, ``SERVER_IP`` and ``SERVER`` from the
        environment, temporarily changes ``dSHeuristics`` and ``minPwdAge``
        (both restored through ``addCleanup``) and adds ``USER_NAME`` with
        password ``USER_PASS``.
        """
        self.message_type = MSG_DSDB_LOG
        self.event_type = DSDB_EVENT_NAME
        super(AuditLogDsdbTests, self).setUp()

        self.remoteAddress = os.environ["CLIENT_IP"]
        self.server_ip = os.environ["SERVER_IP"]

        host = "ldap://%s" % os.environ["SERVER"]
        self.ldb = SamDB(url=host,
                         session_info=system_session(),
                         credentials=self.get_credentials(),
                         lp=self.get_loadparm())
        self.server = os.environ["SERVER"]

        # Gets back the basedn
        self.base_dn = self.ldb.domain_dn()

        # Get the old "dSHeuristics" if it was set
        dsheuristics = self.ldb.get_dsheuristics()

        # Set the "dSHeuristics" to activate the correct "userPassword"
        # behaviour
        self.ldb.set_dsheuristics("000000001")

        # Reset the "dSHeuristics" as they were before
        self.addCleanup(self.ldb.set_dsheuristics, dsheuristics)

        # Get the old "minPwdAge"
        minPwdAge = self.ldb.get_minPwdAge()

        # Set it temporarily to "0"
        self.ldb.set_minPwdAge("0")

        # Reset the "minPwdAge" as it was before
        self.addCleanup(self.ldb.set_minPwdAge, minPwdAge)

        # (Re)adds the test user USER_NAME with password USER_PASS
        delete_force(self.ldb, "cn=" + USER_NAME + ",cn=users," + self.base_dn)
        self.ldb.add({
            "dn": "cn=" + USER_NAME + ",cn=users," + self.base_dn,
            "objectclass": "user",
            "sAMAccountName": USER_NAME,
            "userPassword": USER_PASS
        })

    #
    # Discard the messages from the setup code
    #
    def discardSetupMessages(self, dn):
        """Wait for the two messages setUp generates for *dn*, then drop them."""
        self.waitForMessages(2, dn=dn)
        self.discardMessages()

    def tearDown(self):
        """Drop any pending messages before the base class tears down."""
        self.discardMessages()
        super(AuditLogDsdbTests, self).tearDown()

    def haveExpectedTxn(self, expected):
        """Return True if the stored transaction message has id *expected*."""
        if self.context["txnMessage"] is not None:
            txn = self.context["txnMessage"]["dsdbTransaction"]
            if txn["transactionId"] == expected:
                return True
        return False

    def waitForTransaction(self, expected, connection=None):
        """Wait for a transaction message to arrive
        The connection is passed through to keep the connection alive
        until all the logging messages have been received.

        Returns the transaction message, or an empty string if the
        expected transaction has not been seen after about one second.
        """

        self.connection = connection

        start_time = time.time()
        while not self.haveExpectedTxn(expected):
            self.msg_ctx.loop_once(0.1)
            if time.time() - start_time > 1:
                self.connection = None
                return ""

        self.connection = None
        return self.context["txnMessage"]

    def test_net_change_password(self):
        """Net.change_password is logged as one redacted Modify."""

        dn = "CN=" + USER_NAME + ",CN=Users," + self.base_dn
        self.discardSetupMessages(dn)

        creds = self.insta_creds(template=self.get_credentials())

        lp = self.get_loadparm()
        net = Net(creds, lp, server=self.server)
        # NOTE(review): this literal looks masked by a secret scrubber;
        # confirm the intended value satisfies the password policy.
        password = "******"

        net.change_password(newpassword=password.encode('utf-8'),
                            username=USER_NAME,
                            oldpassword=USER_PASS)

        messages = self.waitForMessages(1, net, dn=dn)
        print("Received %d messages" % len(messages))
        self.assertEqual(1,
                         len(messages),
                         "Did not receive the expected number of messages")

        audit = messages[0]["dsdbChange"]
        self.assertEqual("Modify", audit["operation"])
        self.assertFalse(audit["performedAsSystem"])
        # Was assertTrue(a, b), which treats b as the message and never fails.
        self.assertEqual(dn.lower(), audit["dn"].lower())
        self.assertRegexpMatches(audit["remoteAddress"],
                                 self.remoteAddress)
        session_id = self.get_session()
        self.assertEqual(session_id, audit["sessionId"])
        # We skip the check for self.get_service_description() as this
        # is subject to a race between smbd and the s4 rpc_server code
        # as to which will set the description as it is DCE/RPC over SMB

        self.assertTrue(self.is_guid(audit["transactionId"]))

        attributes = audit["attributes"]
        self.assertEqual(1, len(attributes))
        actions = attributes["clearTextPassword"]["actions"]
        self.assertEqual(1, len(actions))
        self.assertTrue(actions[0]["redacted"])
        self.assertEqual("replace", actions[0]["action"])

    def test_net_set_password(self):
        """Net.set_password is logged as one redacted Modify."""

        dn = "CN=" + USER_NAME + ",CN=Users," + self.base_dn
        self.discardSetupMessages(dn)

        creds = self.insta_creds(template=self.get_credentials())

        lp = self.get_loadparm()
        net = Net(creds, lp, server=self.server)
        # NOTE(review): this literal looks masked by a secret scrubber;
        # confirm the intended value satisfies the password policy.
        password = "******"
        domain = lp.get("workgroup")

        net.set_password(newpassword=password.encode('utf-8'),
                         account_name=USER_NAME,
                         domain_name=domain)
        messages = self.waitForMessages(1, net, dn=dn)
        print("Received %d messages" % len(messages))
        self.assertEqual(1,
                         len(messages),
                         "Did not receive the expected number of messages")
        audit = messages[0]["dsdbChange"]
        self.assertEqual("Modify", audit["operation"])
        self.assertFalse(audit["performedAsSystem"])
        self.assertEqual(dn, audit["dn"])
        self.assertRegexpMatches(audit["remoteAddress"],
                                 self.remoteAddress)
        session_id = self.get_session()
        self.assertEqual(session_id, audit["sessionId"])
        # We skip the check for self.get_service_description() as this
        # is subject to a race between smbd and the s4 rpc_server code
        # as to which will set the description as it is DCE/RPC over SMB

        self.assertTrue(self.is_guid(audit["transactionId"]))

        attributes = audit["attributes"]
        self.assertEqual(1, len(attributes))
        actions = attributes["clearTextPassword"]["actions"]
        self.assertEqual(1, len(actions))
        self.assertTrue(actions[0]["redacted"])
        self.assertEqual("replace", actions[0]["action"])

    def test_ldap_change_password(self):
        """An LDAP delete+add of userPassword logs two redacted actions."""

        dn = "cn=" + USER_NAME + ",cn=users," + self.base_dn
        self.discardSetupMessages(dn)

        new_password = samba.generate_random_password(32, 32)
        # The old value being deleted is the password setUp created the
        # user with; the added value is the freshly generated one.
        self.ldb.modify_ldif(
            "dn: " + dn + "\n" +
            "changetype: modify\n" +
            "delete: userPassword\n" +
            "userPassword: " + USER_PASS + "\n" +
            "add: userPassword\n" +
            "userPassword: " + new_password + "\n")

        messages = self.waitForMessages(1)
        print("Received %d messages" % len(messages))
        self.assertEqual(1,
                         len(messages),
                         "Did not receive the expected number of messages")

        audit = messages[0]["dsdbChange"]
        self.assertEqual("Modify", audit["operation"])
        self.assertFalse(audit["performedAsSystem"])
        self.assertEqual(dn, audit["dn"])
        self.assertRegexpMatches(audit["remoteAddress"],
                                 self.remoteAddress)
        self.assertTrue(self.is_guid(audit["sessionId"]))
        session_id = self.get_session()
        self.assertEqual(session_id, audit["sessionId"])
        service_description = self.get_service_description()
        self.assertEqual(service_description, "LDAP")

        attributes = audit["attributes"]
        self.assertEqual(1, len(attributes))
        actions = attributes["userPassword"]["actions"]
        self.assertEqual(2, len(actions))
        self.assertTrue(actions[0]["redacted"])
        self.assertEqual("delete", actions[0]["action"])
        self.assertTrue(actions[1]["redacted"])
        self.assertEqual("add", actions[1]["action"])

    def test_ldap_replace_password(self):
        """An LDAP replace of userPassword logs one redacted replace."""

        dn = "cn=" + USER_NAME + ",cn=users," + self.base_dn
        self.discardSetupMessages(dn)

        new_password = samba.generate_random_password(32, 32)
        self.ldb.modify_ldif(
            "dn: " + dn + "\n" +
            "changetype: modify\n" +
            "replace: userPassword\n" +
            "userPassword: " + new_password + "\n")

        messages = self.waitForMessages(1, dn=dn)
        print("Received %d messages" % len(messages))
        self.assertEqual(1,
                         len(messages),
                         "Did not receive the expected number of messages")

        audit = messages[0]["dsdbChange"]
        self.assertEqual("Modify", audit["operation"])
        self.assertFalse(audit["performedAsSystem"])
        self.assertEqual(dn.lower(), audit["dn"].lower())
        self.assertRegexpMatches(audit["remoteAddress"],
                                 self.remoteAddress)
        self.assertTrue(self.is_guid(audit["sessionId"]))
        session_id = self.get_session()
        self.assertEqual(session_id, audit["sessionId"])
        service_description = self.get_service_description()
        self.assertEqual(service_description, "LDAP")
        self.assertTrue(self.is_guid(audit["transactionId"]))

        attributes = audit["attributes"]
        self.assertEqual(1, len(attributes))
        actions = attributes["userPassword"]["actions"]
        self.assertEqual(1, len(actions))
        self.assertTrue(actions[0]["redacted"])
        self.assertEqual("replace", actions[0]["action"])

    def test_ldap_add_user(self):
        """The user add done by setUp is logged with its three attributes."""

        # The setup code adds a user, so we check for the dsdb events
        # generated by it.
        dn = "cn=" + USER_NAME + ",cn=users," + self.base_dn
        messages = self.waitForMessages(2, dn=dn)
        print("Received %d messages" % len(messages))
        self.assertEqual(2,
                         len(messages),
                         "Did not receive the expected number of messages")

        audit = messages[1]["dsdbChange"]
        self.assertEqual("Add", audit["operation"])
        self.assertFalse(audit["performedAsSystem"])
        self.assertEqual(dn, audit["dn"])
        self.assertRegexpMatches(audit["remoteAddress"],
                                 self.remoteAddress)
        session_id = self.get_session()
        self.assertEqual(session_id, audit["sessionId"])
        service_description = self.get_service_description()
        self.assertEqual(service_description, "LDAP")
        self.assertTrue(self.is_guid(audit["sessionId"]))
        self.assertTrue(self.is_guid(audit["transactionId"]))

        attributes = audit["attributes"]
        self.assertEqual(3, len(attributes))

        actions = attributes["objectclass"]["actions"]
        self.assertEqual(1, len(actions))
        self.assertEqual("add", actions[0]["action"])
        self.assertEqual(1, len(actions[0]["values"]))
        self.assertEqual("user", actions[0]["values"][0]["value"])

        actions = attributes["sAMAccountName"]["actions"]
        self.assertEqual(1, len(actions))
        self.assertEqual("add", actions[0]["action"])
        self.assertEqual(1, len(actions[0]["values"]))
        self.assertEqual(USER_NAME, actions[0]["values"][0]["value"])

        actions = attributes["userPassword"]["actions"]
        self.assertEqual(1, len(actions))
        self.assertEqual("add", actions[0]["action"])
        self.assertTrue(actions[0]["redacted"])

    def test_samdb_delete_user(self):
        """Deleting the test user logs a Delete and a commit transaction."""

        dn = "cn=" + USER_NAME + ",cn=users," + self.base_dn
        self.discardSetupMessages(dn)

        self.ldb.deleteuser(USER_NAME)

        messages = self.waitForMessages(1, dn=dn)
        print("Received %d messages" % len(messages))
        self.assertEqual(1,
                         len(messages),
                         "Did not receive the expected number of messages")

        audit = messages[0]["dsdbChange"]
        self.assertEqual("Delete", audit["operation"])
        self.assertFalse(audit["performedAsSystem"])
        self.assertEqual(dn.lower(), audit["dn"].lower())
        self.assertRegexpMatches(audit["remoteAddress"],
                                 self.remoteAddress)
        self.assertTrue(self.is_guid(audit["sessionId"]))
        self.assertEqual(0, audit["statusCode"])
        self.assertEqual("Success", audit["status"])
        session_id = self.get_session()
        self.assertEqual(session_id, audit["sessionId"])
        service_description = self.get_service_description()
        self.assertEqual(service_description, "LDAP")

        transactionId = audit["transactionId"]
        message = self.waitForTransaction(transactionId)
        audit = message["dsdbTransaction"]
        self.assertEqual("commit", audit["action"])
        self.assertTrue(self.is_guid(audit["transactionId"]))
        self.assertGreater(audit["duration"], 0)

    def test_samdb_delete_non_existent_dn(self):
        """Deleting a missing DN logs a failed Delete and a rollback."""

        DOES_NOT_EXIST = "doesNotExist"
        dn = "cn=" + USER_NAME + ",cn=users," + self.base_dn
        self.discardSetupMessages(dn)

        dn = "cn=" + DOES_NOT_EXIST + ",cn=users," + self.base_dn
        try:
            self.ldb.delete(dn)
        except Exception:
            pass
        else:
            # Must sit outside the try body: self.fail() raises an
            # AssertionError, which "except Exception" would swallow.
            self.fail("Exception not thrown")

        messages = self.waitForMessages(1)
        print("Received %d messages" % len(messages))
        self.assertEqual(1,
                         len(messages),
                         "Did not receive the expected number of messages")

        audit = messages[0]["dsdbChange"]
        self.assertEqual("Delete", audit["operation"])
        self.assertFalse(audit["performedAsSystem"])
        self.assertEqual(dn.lower(), audit["dn"].lower())
        self.assertRegexpMatches(audit["remoteAddress"],
                                 self.remoteAddress)
        self.assertEqual(ERR_NO_SUCH_OBJECT, audit["statusCode"])
        self.assertEqual("No such object", audit["status"])
        self.assertTrue(self.is_guid(audit["sessionId"]))
        session_id = self.get_session()
        self.assertEqual(session_id, audit["sessionId"])
        service_description = self.get_service_description()
        self.assertEqual(service_description, "LDAP")

        transactionId = audit["transactionId"]
        message = self.waitForTransaction(transactionId)
        audit = message["dsdbTransaction"]
        self.assertEqual("rollback", audit["action"])
        self.assertTrue(self.is_guid(audit["transactionId"]))
        self.assertGreater(audit["duration"], 0)

    def test_create_and_delete_secret_over_lsa(self):
        """Creating and deleting an LSA secret logs an Add, then a Delete."""

        dn = "cn=Test Secret,CN=System," + self.base_dn
        self.discardSetupMessages(dn)

        creds = self.insta_creds(template=self.get_credentials())
        lsa_conn = lsa.lsarpc(
            "ncacn_np:%s" % self.server,
            self.get_loadparm(),
            creds)
        lsa_handle = lsa_conn.OpenPolicy2(
            system_name="\\",
            attr=lsa.ObjectAttribute(),
            access_mask=security.SEC_FLAG_MAXIMUM_ALLOWED)
        secret_name = lsa.String()
        secret_name.string = "G$Test"
        lsa_conn.CreateSecret(
            handle=lsa_handle,
            name=secret_name,
            access_mask=security.SEC_FLAG_MAXIMUM_ALLOWED)

        messages = self.waitForMessages(1, dn=dn)
        print("Received %d messages" % len(messages))
        self.assertEqual(1,
                         len(messages),
                         "Did not receive the expected number of messages")

        audit = messages[0]["dsdbChange"]
        self.assertEqual("Add", audit["operation"])
        self.assertTrue(audit["performedAsSystem"])
        self.assertEqual(dn.lower(), audit["dn"].lower())
        self.assertRegexpMatches(audit["remoteAddress"],
                                 self.remoteAddress)
        self.assertTrue(self.is_guid(audit["sessionId"]))
        session_id = self.get_session()
        self.assertEqual(session_id, audit["sessionId"])

        # We skip the check for self.get_service_description() as this
        # is subject to a race between smbd and the s4 rpc_server code
        # as to which will set the description as it is DCE/RPC over SMB

        attributes = audit["attributes"]
        self.assertEqual(2, len(attributes))

        object_class = attributes["objectClass"]
        self.assertEqual(1, len(object_class["actions"]))
        action = object_class["actions"][0]
        self.assertEqual("add", action["action"])
        values = action["values"]
        self.assertEqual(1, len(values))
        self.assertEqual("secret", values[0]["value"])

        cn = attributes["cn"]
        self.assertEqual(1, len(cn["actions"]))
        action = cn["actions"][0]
        self.assertEqual("add", action["action"])
        values = action["values"]
        self.assertEqual(1, len(values))
        self.assertEqual("Test Secret", values[0]["value"])

        #
        # Now delete the secret.
        self.discardMessages()
        h = lsa_conn.OpenSecret(
            handle=lsa_handle,
            name=secret_name,
            access_mask=security.SEC_FLAG_MAXIMUM_ALLOWED)

        lsa_conn.DeleteObject(h)
        messages = self.waitForMessages(1, dn=dn)
        print("Received %d messages" % len(messages))
        self.assertEqual(1,
                         len(messages),
                         "Did not receive the expected number of messages")

        audit = messages[0]["dsdbChange"]
        self.assertEqual("Delete", audit["operation"])
        self.assertTrue(audit["performedAsSystem"])
        self.assertEqual(dn.lower(), audit["dn"].lower())
        self.assertRegexpMatches(audit["remoteAddress"],
                                 self.remoteAddress)
        self.assertTrue(self.is_guid(audit["sessionId"]))
        session_id = self.get_session()
        self.assertEqual(session_id, audit["sessionId"])

        # We skip the check for self.get_service_description() as this
        # is subject to a race between smbd and the s4 rpc_server code
        # as to which will set the description as it is DCE/RPC over SMB

    def test_modify(self):
        """Add, delete and replace of carLicense values are each logged."""

        dn = "cn=" + USER_NAME + ",cn=users," + self.base_dn
        self.discardSetupMessages(dn)

        #
        # Add an attribute value
        #
        self.ldb.modify_ldif(
            "dn: " + dn + "\n" +
            "changetype: modify\n" +
            "add: carLicense\n" +
            "carLicense: license-01\n")

        messages = self.waitForMessages(1, dn=dn)
        print("Received %d messages" % len(messages))
        self.assertEqual(1,
                         len(messages),
                         "Did not receive the expected number of messages")

        audit = messages[0]["dsdbChange"]
        self.assertEqual("Modify", audit["operation"])
        self.assertFalse(audit["performedAsSystem"])
        self.assertEqual(dn, audit["dn"])
        self.assertRegexpMatches(audit["remoteAddress"],
                                 self.remoteAddress)
        self.assertTrue(self.is_guid(audit["sessionId"]))
        session_id = self.get_session()
        self.assertEqual(session_id, audit["sessionId"])
        service_description = self.get_service_description()
        self.assertEqual(service_description, "LDAP")

        attributes = audit["attributes"]
        self.assertEqual(1, len(attributes))
        actions = attributes["carLicense"]["actions"]
        self.assertEqual(1, len(actions))
        self.assertEqual("add", actions[0]["action"])
        values = actions[0]["values"]
        self.assertEqual(1, len(values))
        self.assertEqual("license-01", values[0]["value"])

        #
        # Add another value to the attribute
        #
        self.discardMessages()
        self.ldb.modify_ldif(
            "dn: " + dn + "\n" +
            "changetype: modify\n" +
            "add: carLicense\n" +
            "carLicense: license-02\n")

        messages = self.waitForMessages(1, dn=dn)
        print("Received %d messages" % len(messages))
        self.assertEqual(1,
                         len(messages),
                         "Did not receive the expected number of messages")
        attributes = messages[0]["dsdbChange"]["attributes"]
        self.assertEqual(1, len(attributes))
        actions = attributes["carLicense"]["actions"]
        self.assertEqual(1, len(actions))
        self.assertEqual("add", actions[0]["action"])
        values = actions[0]["values"]
        self.assertEqual(1, len(values))
        self.assertEqual("license-02", values[0]["value"])

        #
        # Add another two values to the attribute
        #
        self.discardMessages()
        self.ldb.modify_ldif(
            "dn: " + dn + "\n" +
            "changetype: modify\n" +
            "add: carLicense\n" +
            "carLicense: license-03\n" +
            "carLicense: license-04\n")

        messages = self.waitForMessages(1, dn=dn)
        print("Received %d messages" % len(messages))
        self.assertEqual(1,
                         len(messages),
                         "Did not receive the expected number of messages")
        attributes = messages[0]["dsdbChange"]["attributes"]
        self.assertEqual(1, len(attributes))
        actions = attributes["carLicense"]["actions"]
        self.assertEqual(1, len(actions))
        self.assertEqual("add", actions[0]["action"])
        values = actions[0]["values"]
        self.assertEqual(2, len(values))
        self.assertEqual("license-03", values[0]["value"])
        self.assertEqual("license-04", values[1]["value"])

        #
        # Delete two values from the attribute
        #
        # The changetype was "delete", which asks for removal of the whole
        # entry; the assertions below expect an attribute-level change,
        # so the record is sent as a modification.
        self.discardMessages()
        self.ldb.modify_ldif(
            "dn: " + dn + "\n" +
            "changetype: modify\n" +
            "delete: carLicense\n" +
            "carLicense: license-03\n" +
            "carLicense: license-04\n")

        messages = self.waitForMessages(1, dn=dn)
        print("Received %d messages" % len(messages))
        self.assertEqual(1,
                         len(messages),
                         "Did not receive the expected number of messages")
        attributes = messages[0]["dsdbChange"]["attributes"]
        self.assertEqual(1, len(attributes))
        actions = attributes["carLicense"]["actions"]
        self.assertEqual(1, len(actions))
        self.assertEqual("delete", actions[0]["action"])
        values = actions[0]["values"]
        self.assertEqual(2, len(values))
        self.assertEqual("license-03", values[0]["value"])
        self.assertEqual("license-04", values[1]["value"])

        #
        # Replace the attribute with two new values
        #
        # Same changetype correction as for the delete step above.
        self.discardMessages()
        self.ldb.modify_ldif(
            "dn: " + dn + "\n" +
            "changetype: modify\n" +
            "replace: carLicense\n" +
            "carLicense: license-05\n" +
            "carLicense: license-06\n")

        messages = self.waitForMessages(1, dn=dn)
        print("Received %d messages" % len(messages))
        self.assertEqual(1,
                         len(messages),
                         "Did not receive the expected number of messages")
        attributes = messages[0]["dsdbChange"]["attributes"]
        self.assertEqual(1, len(attributes))
        actions = attributes["carLicense"]["actions"]
        self.assertEqual(1, len(actions))
        self.assertEqual("replace", actions[0]["action"])
        values = actions[0]["values"]
        self.assertEqual(2, len(values))
        self.assertEqual("license-05", values[0]["value"])
        self.assertEqual("license-06", values[1]["value"])
Exemple #35
0
def _revert_modify_ldif_entries(ldif_content):
    """Turn modify-mode LDIF into per-entry line lists that undo its adds.

    A new entry starts at every ``dn:`` line and at every empty line.
    ``add:`` operations are rewritten to ``delete:``; everything from a
    ``replace:`` line up to the next ``add:``/``dn:``/empty line is dropped
    because the previous value is unknown.  Comment lines and
    whitespace-only lines are skipped.

    :param ldif_content: LDIF text
    :return: list of line lists, in reverse order of appearance (the list
        may contain empty or blank-only entries)
    """
    entries = [[]]
    keep_line = False
    for line in ldif_content.splitlines():
        skip_this_line = False
        if line.startswith("dn:") or line == "":
            entries.append([])
            keep_line = True
        elif line.startswith("add:"):
            keep_line = True
            line = "delete:" + line[4:]
        elif line.startswith("replace:"):
            keep_line = False
        elif line.startswith("#") or line.strip() == "":
            skip_this_line = True

        if keep_line and not skip_this_line:
            entries[-1].append(line)

    entries.reverse()
    return entries


def deprovision_schema(setup_path,
                       names,
                       lp,
                       creds,
                       reporter,
                       ldif,
                       msg,
                       modify_mode=False):
    """Deprovision/unmodify schema using LDIF specified file, by reverting the
    modifications contained therein.

    In modify mode every ``add:`` of the LDIF is replayed as a ``delete:``
    (last entry first); a failing entry is reported and skipped.  Otherwise
    every ``dn:`` found in the LDIF is deleted, last one first.  All changes
    run inside one transaction that is cancelled if an exception escapes.

    :param setup_path: Path to the setup directory.
    :param names: provision names object.
    :param lp: Loadparm context
    :param creds: Credentials Context
    :param reporter: A progress reporter instance (subclass of AbstractProgressReporter)
    :param ldif: path to the LDIF file
    :param msg: reporter message
    :param modify_mode: whether entries are added or modified
    """

    session_info = system_session()
    db = SamDB(url=get_ldb_url(lp, creds, names),
               session_info=session_info,
               credentials=creds,
               lp=lp)

    db.transaction_start()

    try:
        reporter.reportNextStep(msg)

        ldif_content = read_and_sub_file(
            setup_path(ldif), {
                "FIRSTORG": names.firstorg,
                "FIRSTORGDN": names.firstorgdn,
                "CONFIGDN": names.configdn,
                "SCHEMADN": names.schemadn,
                "DOMAINDN": names.domaindn,
                "DOMAIN": names.domain,
                "DNSDOMAIN": names.dnsdomain,
                "NETBIOSNAME": names.netbiosname,
                "HOSTNAME": names.hostname
            })
        if modify_mode:
            for entry in _revert_modify_ldif_entries(ldif_content):
                entry_ldif = "\n".join(entry)
                if not entry_ldif.strip():
                    # Artefact of the splitting (leading list, blank
                    # separator lines): nothing to send to the database.
                    continue
                print(entry_ldif)
                try:
                    db.modify_ldif(entry_ldif)
                except Exception as e:
                    # Best effort: one entry that cannot be reverted must
                    # not stop the remaining ones, but say so instead of
                    # hiding it.  Ctrl-C / SystemExit are not caught here.
                    print("Ignoring failed revert: %s" % (e, ))
        else:
            for line in reversed(ldif_content.splitlines()):
                if line.startswith("dn:"):
                    # Tolerate any amount of whitespace after "dn:".
                    db.delete(line[3:].strip())

    except:
        # Intentionally bare: whatever interrupted us, roll back and
        # re-raise the original exception.
        db.transaction_cancel()
        raise

    db.transaction_commit()
Exemple #36
0
class BaseSortTests(samba.tests.TestCase):
    avoid_tricky_sort = False
    maxDiff = 2000

    def create_user(self,
                    i,
                    n,
                    prefix='sorttest',
                    suffix='',
                    attrs=None,
                    tricky=False):
        name = "%s%d%s" % (prefix, i, suffix)
        user = {
            'cn':
            name,
            "objectclass":
            "user",
            'givenName':
            "abcdefghijklmnopqrstuvwxyz"[i % 26],
            "roomNumber":
            "%sb\x00c" % (n - i),
            # with python3 re.sub(r'[^\w,.]', repl, string) doesn't
            # work as expected with unicode as value for carLicense
            "carLicense":
            "XXXXXXXXX" if self.avoid_tricky_sort else "后来经",
            "employeeNumber":
            "%s%sx" % (abs(i * (99 - i)), '\n' * (i & 255)),
            "accountExpires":
            "%s" % (10**9 + 1000000 * i),
            "msTSExpireDate4":
            "19%02d0101010000.0Z" % (i % 100),
            "flags":
            str(i * (n - i)),
            "serialNumber":
            "abc %s%s%s" % (
                'AaBb |-/'[i & 7],
                ' 3z}'[i & 3],
                '"@'[i & 1],
            ),
            "comment":
            "Favourite colour is %d" % (n % (i + 1)),
        }

        if self.avoid_tricky_sort:
            # We are not even going to try passing tests that assume
            # some kind of Unicode awareness.
            for k, v in user.items():
                user[k] = re.sub(r'[^\w,.]', 'X', v)
        else:
            # Add some even trickier ones!
            fiendish_index = i % len(FIENDISH_TESTS)
            user.update({
                # Sort doesn't look past a NUL byte.
                "photo":
                "\x00%d" % (n - i),
                "audio":
                "%sn octet string %s%s ♫♬\x00lalala" %
                ('Aa'[i & 1], chr(i & 255), i),
                "displayNamePrintable":
                "%d\x00%c" % (i, i & 255),
                "adminDisplayName":
                "%d\x00b" % (n - i),
                "title":
                "%d%sb" % (n - i, '\x00' * i),

                # Names that vary only in case. Windows returns
                # equivalent addresses in the order they were put
                # in ('a st', 'A st',...). We don't check that.
                "street":
                "%s st" % (chr(65 | (i & 14) | ((i & 1) * 32))),
                "streetAddress":
                FIENDISH_TESTS[fiendish_index],
                "postalAddress":
                FIENDISH_TESTS[-fiendish_index],
            })

        if attrs is not None:
            user.update(attrs)

        user['dn'] = "cn=%s,%s" % (user['cn'], self.ou)

        self.users.append(user)
        self.ldb.add(user)
        return user

    def setUp(self):
        """Connect to the server, fill a fresh ``ou=sort`` OU with test
        users and work out the expected sort orders.

        The number of users comes from ``opts.elements``. Attributes are
        grouped into binary-, numeric- and locale-sorted sets; expected
        orders for the binary ones are computed here, while the others
        are read from ``self.results_file`` (lines of the form
        ``key = [list literal]``).
        """
        super(BaseSortTests, self).setUp()
        self.ldb = SamDB(host,
                         credentials=creds,
                         session_info=system_session(lp),
                         lp=lp)

        self.base_dn = self.ldb.domain_dn()
        self.ou = "ou=sort,%s" % self.base_dn

        self.ldb.add({"dn": self.ou, "objectclass": "organizationalUnit"})
        self.users = []
        n = opts.elements
        for i in range(n):
            self.create_user(i, n)

        attrs = set(self.users[0].keys()) - {'objectclass', 'dn'}
        self.binary_sorted_keys = attrs.intersection([
            'audio', 'photo', "msTSExpireDate4", 'serialNumber',
            "displayNamePrintable"
        ])

        self.numeric_sorted_keys = attrs.intersection(
            ['flags', 'accountExpires'])

        self.timestamp_keys = attrs.intersection(['msTSExpireDate4'])

        self.int64_keys = {'accountExpires'}

        self.locale_sorted_keys = [
            x for x in attrs
            if x not in (self.binary_sorted_keys | self.numeric_sorted_keys)
        ]

        self.expected_results = {}
        self.expected_results_binary = {}

        for k in self.binary_sorted_keys:
            forward = sorted(x[k] for x in self.users)
            reverse = list(reversed(forward))
            self.expected_results_binary[k] = (forward, reverse)

        # FYI: Expected result data was generated from the old
        # code that was manually sorting (while executing with
        # python2)
        # The resulting data was injected into the data file with
        # code similar to:
        #
        # for k in self.expected_results:
        #     f.write("%s = %s\n" % (k,  repr(self.expected_results[k][0])))
        import ast

        # The context manager closes the file even if a line fails to parse.
        with open(self.results_file, "r") as f:
            for line in f:
                key, sep, value = line.partition('=')
                if not sep:
                    # No '=' on this line, so it is not a "key = value" entry.
                    continue
                value = value.strip()
                if value.startswith('['):
                    fwd_list = ast.literal_eval(value)
                    rev_list = list(reversed(fwd_list))
                    self.expected_results[key.strip()] = (fwd_list, rev_list)

    def tearDown(self):
        """Run the base-class teardown, then delete the test OU with the
        ``tree_delete:1`` control."""
        super(BaseSortTests, self).tearDown()
        delete_controls = ['tree_delete:1']
        self.ldb.delete(self.ou, delete_controls)

    def _test_server_sort_default(self):
        """Request a server-side sort on each locale-sorted attribute, in
        forward and reverse direction, and compare the (normalised) order
        received with the recorded expected order."""
        attrs = self.locale_sorted_keys

        for attr in attrs:
            for rev in (0, 1):
                res = self.ldb.search(
                    self.ou,
                    scope=ldb.SCOPE_ONELEVEL,
                    attrs=[attr],
                    controls=["server_sort:1:%d:%s" % (rev, attr)])
                self.assertEqual(len(res), len(self.users))

                expected_order = self.expected_results[attr][rev]
                received_order = [norm(x[attr][0]) for x in res]
                if expected_order != received_order:
                    # Print diagnostics before the assertion below fails.
                    print(attr, ['forward', 'reverse'][rev])
                    print("expected", expected_order)
                    print("received", received_order)
                    print("unnormalised:", [x[attr][0] for x in res])
                    print("unnormalised: «%s»" %
                          '»  «'.join(str(x[attr][0]) for x in res))
                # assertEqual, not the assertEquals alias, which was
                # removed from unittest in Python 3.12.
                self.assertEqual(expected_order, received_order)

    def _test_server_sort_binary(self):
        """Request a server-side sort on each binary-sorted attribute, in
        forward and reverse direction, and compare the order received
        with the order computed locally in setUp."""
        for attr in self.binary_sorted_keys:
            for rev in (0, 1):
                res = self.ldb.search(
                    self.ou,
                    scope=ldb.SCOPE_ONELEVEL,
                    attrs=[attr],
                    controls=["server_sort:1:%d:%s" % (rev, attr)])

                self.assertEqual(len(res), len(self.users))
                expected_order = self.expected_results_binary[attr][rev]
                received_order = [str(x[attr][0]) for x in res]
                if expected_order != received_order:
                    # Print diagnostics before the assertion below fails.
                    print(attr)
                    print(expected_order)
                    print(received_order)
                # assertEqual, not the assertEquals alias, which was
                # removed from unittest in Python 3.12.
                self.assertEqual(expected_order, received_order)

    def _test_server_sort_us_english(self):
        """Repeat the locale-sorted checks while naming an explicit
        matching-rule OID in the server_sort control."""
        # Windows doesn't support many matching rules, but does allow
        # the locale specific sorts -- if it has the locale installed.
        # The most reliable locale is the default US English, which
        # won't change the sort order.

        for lang, oid in [
            ('en_US', '1.2.840.113556.1.4.1499'),
        ]:

            for attr in self.locale_sorted_keys:
                for rev in (0, 1):
                    res = self.ldb.search(
                        self.ou,
                        scope=ldb.SCOPE_ONELEVEL,
                        attrs=[attr],
                        controls=["server_sort:1:%d:%s:%s" % (rev, attr, oid)])

                    # assertEqual reports both lengths on failure, unlike
                    # assertTrue on a pre-computed comparison.
                    self.assertEqual(len(res), len(self.users))
                    expected_order = self.expected_results[attr][rev]
                    received_order = [norm(x[attr][0]) for x in res]
                    if expected_order != received_order:
                        # Print diagnostics before the assertion below fails.
                        print(attr, lang)
                        print(['forward', 'reverse'][rev])
                        print("expected: ", expected_order)
                        print("received: ", received_order)
                        print("unnormalised:", [x[attr][0] for x in res])
                        print("unnormalised: «%s»" %
                              '»  «'.join(str(x[attr][0]) for x in res))

                    # assertEqual, not the assertEquals alias, which was
                    # removed from unittest in Python 3.12.
                    self.assertEqual(expected_order, received_order)

    def _test_server_sort_different_attr(self):
        """Sort on one attribute while requesting a different one.

        For each (sort attribute, result attribute) pair the expected
        order is computed locally from ``self.users``; the server's
        result must match it and must not include the sort attribute,
        since it was not requested.
        """
        def cmp_locale(a, b):
            return locale.strcoll(a[0], b[0])

        def cmp_binary(a, b):
            return cmp_fn(a[0], b[0])

        def cmp_numeric(a, b):
            return cmp_fn(int(a[0]), int(b[0]))

        # For testing simplicity, the attributes in here need to be
        # unique for each user. Otherwise there are multiple possible
        # valid answers.
        sort_functions = {
            'cn': cmp_binary,
            "employeeNumber": cmp_locale,
            "accountExpires": cmp_numeric,
            "msTSExpireDate4": cmp_binary
        }
        attrs = list(sort_functions.keys())
        # Pair each attribute with the next one, wrapping around at the end.
        attr_pairs = zip(attrs, attrs[1:] + attrs[:1])

        for sort_attr, result_attr in attr_pairs:
            forward = sorted(((norm(x[sort_attr]), norm(x[result_attr]))
                              for x in self.users),
                             key=cmp_to_key_fn(sort_functions[sort_attr]))
            reverse = list(reversed(forward))

            for rev in (0, 1):
                res = self.ldb.search(
                    self.ou,
                    scope=ldb.SCOPE_ONELEVEL,
                    attrs=[result_attr],
                    controls=["server_sort:1:%d:%s" % (rev, sort_attr)])
                self.assertEqual(len(res), len(self.users))
                pairs = (forward, reverse)[rev]

                expected_order = [x[1] for x in pairs]
                received_order = [norm(x[result_attr][0]) for x in res]

                if expected_order != received_order:
                    # Print diagnostics before the assertion below fails.
                    print(sort_attr, result_attr, ['forward', 'reverse'][rev])
                    print("expected", expected_order)
                    print("received", received_order)
                    print("unnormalised:", [x[result_attr][0] for x in res])
                    print("unnormalised: «%s»" %
                          '»  «'.join(str(x[result_attr][0]) for x in res))
                    print("pairs:", pairs)
                    # There are bugs in Windows that we don't want (or
                    # know how) to replicate regarding timestamp sorting.
                    # Let's remind ourselves.
                    if result_attr == "msTSExpireDate4":
                        print('-' * 72)
                        print("This test fails against Windows with the "
                              "default number of elements (33).")
                        print("Try with --elements=27 (or similar).")
                        print('-' * 72)

                # assertEqual, not the assertEquals alias, which was
                # removed from unittest in Python 3.12.
                self.assertEqual(expected_order, received_order)
                for x in res:
                    if sort_attr in x:
                        self.fail('the search for %s should not return %s' %
                                  (result_attr, sort_attr))
Exemple #37
0
class DsdbLockTestCase(SamDBTestCase):
    def test_db_lock1(self):
        """Check that iterating a search takes about two seconds while a
        forked child sits on a prepared transaction.

        The child adds and deletes a user inside a transaction, calls
        transaction_prepare_commit(), signals the parent over a pipe,
        sleeps two seconds and then cancels. The parent times a full
        search iteration and expects it to last more than 1.9 seconds.
        """
        basedn = self.samdb.get_default_basedn()
        # Pipe used by the child to tell the parent it has prepared.
        (r1, w1) = os.pipe()

        pid = os.fork()
        if pid == 0:
            # In the child, close the main DB, re-open just one DB
            del(self.samdb)
            gc.collect()
            self.samdb = SamDB(session_info=self.session,
                               lp=self.lp)

            self.samdb.transaction_start()

            dn = "cn=test_db_lock_user,cn=users," + str(basedn)
            self.samdb.add({
                 "dn": dn,
                 "objectclass": "user",
            })
            self.samdb.delete(dn)

            # Obtain a write lock
            self.samdb.transaction_prepare_commit()
            os.write(w1, b"prepared")
            time.sleep(2)

            # Drop the write lock
            self.samdb.transaction_cancel()
            # Leave the child immediately, without returning into the
            # test runner.
            os._exit(0)

        # Parent: wait until the child reports that it has prepared.
        self.assertEqual(os.read(r1, 8), b"prepared")

        start = time.time()

        # We need to hold this iterator open to hold the all-record lock.
        res = self.samdb.search_iterator()

        # This should take at least 2 seconds because the transaction
        # has a write lock on one backend db open

        # Release the locks
        for l in res:
            pass

        end = time.time()
        self.assertGreater(end - start, 1.9)

        # Reap the child and check that it exited cleanly.
        (got_pid, status) = os.waitpid(pid, 0)
        self.assertEqual(got_pid, pid)
        self.assertTrue(os.WIFEXITED(status))
        self.assertEqual(os.WEXITSTATUS(status), 0)

    def test_db_lock2(self):
        """Check that transaction_prepare_commit() in the parent blocks
        while a forked child keeps a search iterator open.

        Parent and child step through the scenario in lock-step using two
        pipes (r1/w1: parent to child, r2/w2: child to parent). The child
        holds its iterator for two seconds after announcing "prepare";
        the parent expects its prepare-commit call to take more than 1.9
        seconds. The child exits with a non-zero status if it reads an
        unexpected message.
        """
        basedn = self.samdb.get_default_basedn()
        (r1, w1) = os.pipe()
        (r2, w2) = os.pipe()

        pid = os.fork()
        if pid == 0:
            # In the child, close the main DB, re-open
            del(self.samdb)
            gc.collect()
            self.samdb = SamDB(session_info=self.session,
                               lp=self.lp)

            # We need to hold this iterator open to hold the all-record lock.
            res = self.samdb.search_iterator()

            os.write(w2, b"start")
            if (os.read(r1, 7) != b"started"):
                os._exit(1)

            os.write(w2, b"add")
            if (os.read(r1, 5) != b"added"):
                os._exit(2)

            # Wait 2 seconds to block prepare_commit() in the parent.
            os.write(w2, b"prepare")
            time.sleep(2)

            # Release the locks
            for l in res:
                pass

            if (os.read(r1, 8) != b"prepared"):
                os._exit(3)

            os._exit(0)

        # We can start the transaction during the search
        # because both just grab the all-record read lock.
        self.assertEqual(os.read(r2, 5), b"start")
        self.samdb.transaction_start()
        os.write(w1, b"started")

        self.assertEqual(os.read(r2, 3), b"add")
        dn = "cn=test_db_lock_user,cn=users," + str(basedn)
        self.samdb.add({
             "dn": dn,
             "objectclass": "user",
        })
        self.samdb.delete(dn)
        os.write(w1, b"added")

        # Obtain a write lock, this will block until
        # the child releases the read lock.
        self.assertEqual(os.read(r2, 7), b"prepare")
        start = time.time()
        self.samdb.transaction_prepare_commit()
        end = time.time()
        try:
            self.assertGreater(end - start, 1.9)
        finally:
            # Runs even if the timing assertion fails, so the child is
            # always unblocked and reaped.
            os.write(w1, b"prepared")

            # Drop the write lock
            self.samdb.transaction_cancel()

            (got_pid, status) = os.waitpid(pid, 0)
            self.assertEqual(got_pid, pid)
            self.assertTrue(os.WIFEXITED(status))
            self.assertEqual(os.WEXITSTATUS(status), 0)

    def test_db_lock3(self):
        """Check that transaction_prepare_commit() in the parent blocks
        while a forked child keeps a search iterator open, when the
        parent's change is to the "@DSDB_LOCK_TEST" record.

        Parent and child step through the scenario in lock-step using two
        pipes (r1/w1: parent to child, r2/w2: child to parent). The child
        exits with a non-zero status if it reads an unexpected message.
        """
        basedn = self.samdb.get_default_basedn()
        (r1, w1) = os.pipe()
        (r2, w2) = os.pipe()

        pid = os.fork()
        if pid == 0:
            # In the child, close the main DB, re-open
            del(self.samdb)
            gc.collect()
            self.samdb = SamDB(session_info=self.session,
                               lp=self.lp)

            # We need to hold this iterator open to hold the all-record lock.
            res = self.samdb.search_iterator()

            os.write(w2, b"start")
            if (os.read(r1, 7) != b"started"):
                os._exit(1)

            os.write(w2, b"add")
            if (os.read(r1, 5) != b"added"):
                os._exit(2)

            # Wait 2 seconds to block prepare_commit() in the parent.
            os.write(w2, b"prepare")
            time.sleep(2)

            # Release the locks
            for l in res:
                pass

            if (os.read(r1, 8) != b"prepared"):
                os._exit(3)

            os._exit(0)

        # We can start the transaction during the search
        # because both just grab the all-record read lock.
        self.assertEqual(os.read(r2, 5), b"start")
        self.samdb.transaction_start()
        os.write(w1, b"started")

        self.assertEqual(os.read(r2, 3), b"add")

        # This will end up in the top level db
        dn = "@DSDB_LOCK_TEST"
        self.samdb.add({
             "dn": dn})
        self.samdb.delete(dn)
        os.write(w1, b"added")

        # Obtain a write lock, this will block until
        # the child releases the read lock.
        self.assertEqual(os.read(r2, 7), b"prepare")
        start = time.time()
        self.samdb.transaction_prepare_commit()
        end = time.time()
        # NOTE(review): if this assertion fails, the child is never sent
        # "prepared" and is not reaped here.
        self.assertGreater(end - start, 1.9)
        os.write(w1, b"prepared")

        # Drop the write lock
        self.samdb.transaction_cancel()

        # Reap the child and check that it exited cleanly.
        (got_pid, status) = os.waitpid(pid, 0)
        self.assertTrue(os.WIFEXITED(status))
        self.assertEqual(os.WEXITSTATUS(status), 0)
        self.assertEqual(got_pid, pid)


    def _test_full_db_lock1(self, backend_path):
        """Check that starting a search on the full database takes about
        two seconds while a forked child sits on a prepared transaction
        on the single backend file at *backend_path*.

        Unlike test_db_lock1, the elapsed time is measured around the
        search_iterator() call itself, before the results are iterated.
        """
        # Pipe used by the child to tell the parent it has prepared.
        (r1, w1) = os.pipe()

        pid = os.fork()
        if pid == 0:
            # In the child, close the main DB, re-open just one DB
            del(self.samdb)
            gc.collect()

            backenddb = ldb.Ldb(backend_path)


            backenddb.transaction_start()

            backenddb.add({"dn":"@DSDB_LOCK_TEST"})
            backenddb.delete("@DSDB_LOCK_TEST")

            # Obtain a write lock
            backenddb.transaction_prepare_commit()
            os.write(w1, b"prepared")
            time.sleep(2)

            # Drop the write lock
            backenddb.transaction_cancel()
            # Leave the child immediately, without returning into the
            # test runner.
            os._exit(0)

        # Parent: wait until the child reports that it has prepared.
        self.assertEqual(os.read(r1, 8), b"prepared")

        start = time.time()

        # We need to hold this iterator open to hold the all-record lock.
        res = self.samdb.search_iterator()

        # This should take at least 2 seconds because the transaction
        # has a write lock on one backend db open

        end = time.time()
        self.assertGreater(end - start, 1.9)

        # Release the locks
        for l in res:
            pass

        # Reap the child and check that it exited cleanly.
        (got_pid, status) = os.waitpid(pid, 0)
        self.assertEqual(got_pid, pid)
        self.assertTrue(os.WIFEXITED(status))
        self.assertEqual(os.WEXITSTATUS(status), 0)

    def test_full_db_lock1(self):
        basedn = self.samdb.get_default_basedn()
        backend_filename = "%s.ldb" % basedn.get_casefold()
        backend_subpath = os.path.join("sam.ldb.d",
                                       backend_filename)
        backend_path = self.lp.private_path(backend_subpath)
        self._test_full_db_lock1(backend_path)


    def test_full_db_lock1_config(self):
        basedn = self.samdb.get_config_basedn()
        backend_filename = "%s.ldb" % basedn.get_casefold()
        backend_subpath = os.path.join("sam.ldb.d",
                                       backend_filename)
        backend_path = self.lp.private_path(backend_subpath)
        self._test_full_db_lock1(backend_path)


    def _test_full_db_lock2(self, backend_path):
        """Check that transaction_prepare_commit() on the single backend
        file at *backend_path* blocks while a forked child keeps a search
        iterator open on the full database.

        Parent and child step through the scenario in lock-step using two
        pipes (r1/w1: parent to child, r2/w2: child to parent). The child
        holds its iterator for two seconds after announcing "prepare";
        the parent expects its prepare-commit call to take more than 1.9
        seconds. The child exits with a non-zero status if it reads an
        unexpected message.
        """
        (r1, w1) = os.pipe()
        (r2, w2) = os.pipe()

        pid = os.fork()
        if pid == 0:

            # In the child, close the main DB, re-open
            del(self.samdb)
            gc.collect()
            self.samdb = SamDB(session_info=self.session,
                               lp=self.lp)

            # We need to hold this iterator open to hold the all-record lock.
            res = self.samdb.search_iterator()

            os.write(w2, b"start")
            if (os.read(r1, 7) != b"started"):
                os._exit(1)
            os.write(w2, b"add")
            if (os.read(r1, 5) != b"added"):
                os._exit(2)

            # Wait 2 seconds to block prepare_commit() in the parent.
            os.write(w2, b"prepare")
            time.sleep(2)

            # Release the locks
            for l in res:
                pass

            if (os.read(r1, 8) != b"prepared"):
                os._exit(3)

            os._exit(0)

        # In the parent, close the main DB, re-open just one DB
        del(self.samdb)
        gc.collect()
        backenddb = ldb.Ldb(backend_path)

        # We can start the transaction during the search
        # because both just grab the all-record read lock.
        self.assertEqual(os.read(r2, 5), b"start")
        backenddb.transaction_start()
        os.write(w1, b"started")

        self.assertEqual(os.read(r2, 3), b"add")
        backenddb.add({"dn":"@DSDB_LOCK_TEST"})
        backenddb.delete("@DSDB_LOCK_TEST")
        os.write(w1, b"added")

        # Obtain a write lock, this will block until
        # the child releases the read lock.
        self.assertEqual(os.read(r2, 7), b"prepare")
        start = time.time()
        backenddb.transaction_prepare_commit()
        end = time.time()

        try:
            self.assertGreater(end - start, 1.9)
        finally:
            # Runs even if the timing assertion fails, so the child is
            # always unblocked and reaped.
            os.write(w1, b"prepared")

            # Drop the write lock
            backenddb.transaction_cancel()

            (got_pid, status) = os.waitpid(pid, 0)
            self.assertEqual(got_pid, pid)
            self.assertTrue(os.WIFEXITED(status))
            self.assertEqual(os.WEXITSTATUS(status), 0)

    def test_full_db_lock2(self):
        basedn = self.samdb.get_default_basedn()
        backend_filename = "%s.ldb" % basedn.get_casefold()
        backend_subpath = os.path.join("sam.ldb.d",
                                       backend_filename)
        backend_path = self.lp.private_path(backend_subpath)
        self._test_full_db_lock2(backend_path)

    def test_full_db_lock2_config(self):
        basedn = self.samdb.get_config_basedn()
        backend_filename = "%s.ldb" % basedn.get_casefold()
        backend_subpath = os.path.join("sam.ldb.d",
                                       backend_filename)
        backend_path = self.lp.private_path(backend_subpath)
        self._test_full_db_lock2(backend_path)
Exemple #38
0
class ConfidentialAttrCommon(samba.tests.TestCase):
    def setUp(self):
        """Connect as admin and create the OU, users and attribute values
        that the confidential-attribute tests operate on.

        Creates one user carrying the attribute under test, one
        unprivileged "sneaky" user (with its own connection in
        ``self.ldb_user``) and two further users that also carry the
        attribute. Finally checks that the attribute's searchFlags do
        not already include the confidential bit.
        """
        super(ConfidentialAttrCommon, self).setUp()

        self.ldb_admin = SamDB(ldaphost,
                               credentials=creds,
                               session_info=system_session(lp),
                               lp=lp)
        self.user_pass = "******"
        self.base_dn = self.ldb_admin.domain_dn()
        self.schema_dn = self.ldb_admin.get_schema_basedn()
        self.sd_utils = sd_utils.SDUtils(self.ldb_admin)

        # the tests work by setting the 'Confidential' bit in the searchFlags
        # for an existing schema attribute. This only works against Windows if
        # the systemFlags does not have FLAG_SCHEMA_BASE_OBJECT set for the
        # schema attribute being modified. There are only a few attributes that
        # meet this criteria (most of which only apply to 'user' objects)
        self.conf_attr = "homePostalAddress"
        attr_cn = "CN=Address-Home"
        # schemaIdGuid for homePostalAddress (used for ACE tests)
        self.conf_attr_guid = "16775781-47f3-11d1-a9c3-0000f80367c1"
        self.conf_attr_sec_guid = "77b5b886-944a-11d1-aebd-0000f80367c1"
        self.attr_dn = "{},{}".format(attr_cn, self.schema_dn)

        userou = "OU=conf-attr-test"
        self.ou = "{},{}".format(userou, self.base_dn)
        self.ldb_admin.create_ou(self.ou)

        # use a common username prefix, so we can use sAMAccountName=CATC-* as
        # a search filter to only return the users we're interested in
        self.user_prefix = "catc-"

        # The username templates below must contain placeholders: a literal
        # without any would make .format() return the same name for every
        # user. The suffixes themselves are arbitrary, they only need to be
        # distinct.

        # add a test object with this attribute set
        self.conf_value = "abcdef"
        self.conf_user = "{}conf-user".format(self.user_prefix)
        self.ldb_admin.newuser(self.conf_user, self.user_pass, userou=userou)
        self.conf_dn = self.get_user_dn(self.conf_user)
        self.add_attr(self.conf_dn, self.conf_attr, self.conf_value)

        # add a sneaky user that will try to steal our secrets
        self.user = "{}sneaky-user".format(self.user_prefix)
        self.ldb_admin.newuser(self.user, self.user_pass, userou=userou)
        self.ldb_user = self.get_ldb_connection(self.user, self.user_pass)

        self.all_users = [self.user, self.conf_user]

        # add some other users that also have confidential attributes, so we can
        # check we don't disclose their details, particularly in '!' searches
        for i in range(1, 3):
            username = "{}other-user{}".format(self.user_prefix, i)
            self.ldb_admin.newuser(username, self.user_pass, userou=userou)
            userdn = self.get_user_dn(username)
            self.add_attr(userdn, self.conf_attr, "xyz{}".format(i))
            self.all_users.append(username)

        # there are 4 users in the OU, plus the OU itself
        self.test_dn = self.ou
        self.total_objects = len(self.all_users) + 1
        self.objects_with_attr = 3

        # sanity-check the flag is not already set (this'll cause problems if
        # previous test run didn't clean up properly)
        search_flags = self.get_attr_search_flags(self.attr_dn)
        self.assertEqual(
            int(search_flags) & SEARCH_FLAG_CONFIDENTIAL, 0,
            "{} searchFlags already {}".format(self.conf_attr, search_flags))

    def tearDown(self):
        """Run the base-class teardown, then delete the test OU with the
        ``tree_delete:1`` control."""
        super(ConfidentialAttrCommon, self).tearDown()
        delete_controls = ["tree_delete:1"]
        self.ldb_admin.delete(self.ou, delete_controls)

    def add_attr(self, dn, attr, value):
        """Add *value* for attribute *attr* to the object at *dn*, through
        the admin connection."""
        change = Message()
        change.dn = Dn(self.ldb_admin, dn)
        element = MessageElement(value, FLAG_MOD_ADD, attr)
        change[attr] = element
        self.ldb_admin.modify(change)

    def set_attr_search_flags(self, attr_dn, flags):
        """Replace the searchFlags of the schema object at *attr_dn* with
        *flags*, then request an immediate schema update."""
        flag_attr = 'searchFlags'

        request = Message()
        request.dn = Dn(self.ldb_admin, attr_dn)
        request[flag_attr] = MessageElement(flags, FLAG_MOD_REPLACE,
                                            flag_attr)
        self.ldb_admin.modify(request)

        # The schema has to be updated before the new flags take effect
        # (on Windows, at least).
        self.ldb_admin.set_schema_update_now()

    def get_attr_search_flags(self, attr_dn):
        """Return the first searchFlags value of the schema object at
        *attr_dn*, as read through the admin connection."""
        found = self.ldb_admin.search(attr_dn,
                                      scope=SCOPE_BASE,
                                      attrs=['searchFlags'])
        schema_entry = found[0]
        return schema_entry['searchFlags'][0]

    def make_attr_confidential(self):
        """Set the attribute's searchFlags to SEARCH_FLAG_CONFIDENTIAL and
        register a cleanup that restores the previous value."""
        schema_dn = self.attr_dn

        # Remember the current flags before they are overwritten.
        previous_flags = self.get_attr_search_flags(schema_dn)

        confidential = str(SEARCH_FLAG_CONFIDENTIAL)
        self.set_attr_search_flags(schema_dn, confidential)

        # Put the old value back once the test has finished.
        self.addCleanup(self.set_attr_search_flags, schema_dn, previous_flags)

    # The behaviour of the DC can differ in some cases, depending on whether
    # we're talking to a Windows DC or a Samba DC
    def guess_dc_mode(self):
        """Return DC_MODE_RETURN_ALL or DC_MODE_RETURN_NONE, depending on
        how the server answers a negative search as the unprivileged
        user."""
        # Inside selftest we can be pretty sure the DC is Samba.
        in_selftest = os.environ.get('SAMBA_SELFTEST') == '1'
        if in_selftest:
            return DC_MODE_RETURN_NONE

        probe = self.get_negative_match_all_searches()[0]
        res = self.ldb_user.search(self.test_dn,
                                   expression=probe,
                                   scope=SCOPE_SUBTREE)

        # Seeing every object suggests a Windows DC; anything else is
        # treated as Samba behaviour, which is the default.
        saw_everything = len(res) == self.total_objects
        return DC_MODE_RETURN_ALL if saw_everything else DC_MODE_RETURN_NONE

    def get_user_dn(self, name):
        return "CN={},{}".format(name, self.ou)

    def get_user_sid_string(self, username):
        user_dn = self.get_user_dn(username)
        user_sid = self.sd_utils.get_object_sid(user_dn)
        return str(user_sid)

    def get_ldb_connection(self, target_username, target_password):
        """Return a new SamDB connection to ``ldaphost`` authenticated as
        the given user.

        Domain, realm and workstation are copied from the module-level
        ``creds``; FEATURE_SEAL is added to the gensec features and the
        Kerberos state is set to DONT_USE_KERBEROS.
        """
        user_creds = Credentials()
        user_creds.set_username(target_username)
        user_creds.set_password(target_password)

        user_creds.set_domain(creds.get_domain())
        user_creds.set_realm(creds.get_realm())
        user_creds.set_workstation(creds.get_workstation())

        user_creds.set_gensec_features(
            user_creds.get_gensec_features() | gensec.FEATURE_SEAL)
        user_creds.set_kerberos_state(DONT_USE_KERBEROS)

        return SamDB(url=ldaphost, credentials=user_creds, lp=lp)

    def assert_not_in_result(self, res, exclude_dn):
        for msg in res:
            self.assertNotEqual(msg.dn, exclude_dn,
                                "Search revealed object {}".format(exclude_dn))

    def assert_search_result(self, expected_num, expr, samdb):
        """Run the subtree search *expr* under the test DN on *samdb* and
        check that it returns exactly *expected_num* results, whichever
        attributes are requested."""
        # Ask for different attributes back: None/all, the confidential
        # attribute itself, and an unrelated attribute.
        for requested in (None, ["*"], [self.conf_attr], ['name']):
            res = samdb.search(self.test_dn,
                               expression=expr,
                               scope=SCOPE_SUBTREE,
                               attrs=requested)
            found = len(res)
            failure_text = ("%u results, not %u for search %s, attr %s" %
                            (found, expected_num, expr, str(requested)))
            self.assertTrue(found == expected_num, failure_text)

    # return a selection of searches that match exactly against the test object
    def get_exact_match_searches(self):
        first_char = self.conf_value[:1]
        last_char = self.conf_value[-1:]
        test_attr = self.conf_attr

        searches = [
            # search for the attribute using a sub-string wildcard
            # (which could reveal the attribute's actual value)
            "({}={}*)".format(test_attr, first_char),
            "({}=*{})".format(test_attr, last_char),

            # sanity-check equality against an exact match on value
            "({}={})".format(test_attr, self.conf_value),

            # '~=' searches don't work against Samba
            # sanity-check an approx search against an exact match on value
            # "({}~={})".format(test_attr, self.conf_value),

            # check wildcard in an AND search...
            "(&({}={}*)(objectclass=*))".format(test_attr, first_char),

            # ...an OR search (against another term that will never match)
            "(|({}={}*)(objectclass=banana))".format(test_attr, first_char)
        ]

        return searches

    # return searches that match any object with the attribute under test
    def get_match_all_searches(self):
        searches = [
            # check a full wildcard against the confidential attribute
            # (which could reveal the attribute's presence/absence)
            "({}=*)".format(self.conf_attr),

            # check wildcard in an AND search...
            "(&(objectclass=*)({}=*))".format(self.conf_attr),

            # ...an OR search (against another term that will never match)
            "(|(objectclass=banana)({}=*))".format(self.conf_attr),

            # check <=, and >= expressions that would normally find a match
            "({}>=0)".format(self.conf_attr),
            "({}<=ZZZZZZZZZZZ)".format(self.conf_attr)
        ]

        return searches

    def assert_conf_attr_searches(self, has_rights_to=0, samdb=None):
        """Check that searches on the attribute under test return what we expect.

        Args:
            has_rights_to: number of objects the user may see the attribute
                on, or the string "all" to use ``self.objects_with_attr``.
            samdb: connection to search with; ``self.ldb_user`` when None.
        """
        connection = self.ldb_user if samdb is None else samdb
        visible = (self.objects_with_attr if has_rights_to == "all"
                   else has_rights_to)

        # The exact-match filters target the single object whose value we are
        # trying to guess, so they yield one result or none.
        single_hit = 1 if visible > 0 else 0
        for expr in self.get_exact_match_searches():
            self.assert_search_result(single_hit, expr, connection)

        # The match-all filters find every object we have rights to see.
        for expr in self.get_match_all_searches():
            self.assert_search_result(visible, expr, connection)

    # The following are double negative searches (i.e. NOT non-matching-
    # condition) which will therefore match ALL objects, including the test
    # object(s).
    def get_negative_match_all_searches(self):
        """Build negated filters whose inner condition misses the test value.

        The inner wildcard uses the character one code point after the first
        (or last) character of ``self.conf_value``, so it does not match that
        value and the surrounding NOT does.

        Returns:
            A list of two LDAP filter strings.
        """
        value = self.conf_value
        wrong_first = chr(ord(value[:1]) + 1)
        wrong_last = chr(ord(value[-1:]) + 1)

        starts_with = "(!({}={}*))".format(self.conf_attr, wrong_first)
        ends_with = "(!({}=*{}))".format(self.conf_attr, wrong_last)
        return [starts_with, ends_with]

    # the following searches will not match against the test object(s). So
    # a user with sufficient rights will see an inverse sub-set of objects.
    # (An unprivileged user would either see all objects on Windows, or no
    # objects on Samba)
    def get_inverse_match_searches(self):
        """Build negated filters whose inner condition matches the test value.

        The inner wildcard uses the real first (or last) character of
        ``self.conf_value``, so the NOT excludes the object under test.

        Returns:
            A list of two LDAP filter strings.
        """
        value = self.conf_value
        starts_with = "(!({}={}*))".format(self.conf_attr, value[:1])
        ends_with = "(!({}=*{}))".format(self.conf_attr, value[-1:])
        return [starts_with, ends_with]

    def negative_searches_all_rights(self, total_objects=None):
        """Map each '!' filter to the result count a fully privileged user sees.

        Args:
            total_objects: number of objects in scope; ``self.total_objects``
                when None.

        Returns:
            A dict of filter string -> expected number of results.
        """
        total = self.total_objects if total_objects is None else total_objects

        # the double-negative filters match every object (including the OU)
        expected = dict.fromkeys(self.get_negative_match_all_searches(), total)

        # a negated presence test only finds objects lacking the attribute
        absent_filter = "(!({}=*))".format(self.conf_attr)
        expected[absent_filter] = total - self.objects_with_attr

        # the inverse filters find everything *except* the object under test
        expected.update(
            (expr, total - 1) for expr in self.get_inverse_match_searches())

        return expected

    # Returns the expected negative (i.e. '!') search behaviour when talking to
    # a DC with DC_MODE_RETURN_ALL behaviour, i.e. we assert that users
    # without rights always see ALL objects in '!' searches
    def negative_searches_return_all(self,
                                     has_rights_to=0,
                                     total_objects=None):
        """Map each '!' filter to its expected count on a return-all DC.

        Args:
            has_rights_to: number of objects the user may see the attribute on.
            total_objects: number of objects in scope; ``self.total_objects``
                when None.

        Returns:
            A dict of filter string -> expected number of results.
        """
        total = self.total_objects if total_objects is None else total_objects

        # Such a DC 'hides' objects by always returning all of them, so the
        # filters that match everything simply return everything.
        expected = {expr: total
                    for expr in self.get_negative_match_all_searches()}

        # The inverse filters and the negated presence test are expected to
        # omit only the objects the user has rights to.
        subset_filters = self.get_inverse_match_searches()
        subset_filters.append("(!({}=*))".format(self.conf_attr))
        remaining = total - has_rights_to
        for expr in subset_filters:
            expected[expr] = remaining

        return expected

    # Returns the expected negative (i.e. '!') search behaviour when talking to
    # a DC with DC_MODE_RETURN_NONE behaviour, i.e. we assert that users
    # without rights cannot see objects in '!' searches at all
    def negative_searches_return_none(self, has_rights_to=0):
        """Map each '!' filter to its expected count on a return-none DC.

        Args:
            has_rights_to: number of objects the user may see the attribute on.

        Returns:
            A dict of filter string -> expected number of results.
        """
        # the 'match-all' filters only reveal the objects we have access
        # rights to (if any)
        expected = dict.fromkeys(self.get_negative_match_all_searches(),
                                 has_rights_to)

        # inverse filters and the negated presence test must reveal nothing
        hidden_filters = self.get_inverse_match_searches()
        hidden_filters.append("(!({}=*))".format(self.conf_attr))
        expected.update((expr, 0) for expr in hidden_filters)

        return expected

    # Returns the expected negative (i.e. '!') search behaviour. This varies
    # depending on what type of DC we're talking to (i.e. Windows or Samba)
    # and what access rights the user has
    def negative_search_expected_results(self,
                                         has_rights_to,
                                         dc_mode,
                                         total_objects=None):
        """Pick the expected '!' search results for this user and DC mode.

        Args:
            has_rights_to: "all", or the number of objects the user may see
                the attribute on.
            dc_mode: DC_MODE_RETURN_NONE selects the return-none expectations;
                any other value selects the return-all expectations.
            total_objects: forwarded to the helpers that accept it.

        Returns:
            A dict of filter string -> expected number of results.
        """
        # full rights: the DC mode makes no difference
        if has_rights_to == "all":
            return self.negative_searches_all_rights(total_objects)

        # only the objects we have access rights to are returned by the
        # 'match-all' searches; all others are hidden
        if dc_mode == DC_MODE_RETURN_NONE:
            return self.negative_searches_return_none(has_rights_to)

        # otherwise the DC 'hides' objects by always returning all of them
        return self.negative_searches_return_all(has_rights_to, total_objects)

    def assert_negative_searches(self,
                                 has_rights_to=0,
                                 dc_mode=DC_MODE_RETURN_NONE,
                                 samdb=None):
        """Run every '!' search and check it returns the expected count.

        Args:
            has_rights_to: "all", or the number of objects the user may see
                the attribute on.
            dc_mode: passed to ``negative_search_expected_results``.
            samdb: connection to search with; ``self.ldb_user`` when None.
        """
        connection = self.ldb_user if samdb is None else samdb

        # dictionary of search expression -> expected number of results
        expectations = self.negative_search_expected_results(has_rights_to,
                                                             dc_mode)

        for expr in expectations:
            self.assert_search_result(expectations[expr], expr, connection)

    def assert_attr_returned(self, expect_attr, samdb, attrs):
        """Check whether the confidential attribute comes back from a search.

        Searches ``self.conf_dn`` with a filter that always matches and
        compares the presence of ``self.conf_attr`` in the result with
        ``expect_attr``.

        Args:
            expect_attr: True if the attribute should be present.
            samdb: connection to search with.
            attrs: attribute list to request in the search.
        """
        results = samdb.search(self.conf_dn,
                               expression="(objectClass=*)",
                               scope=SCOPE_SUBTREE,
                               attrs=attrs)
        self.assertTrue(len(results) == 1)

        found = any(self.conf_attr in entry for entry in results)
        self.assertEqual(expect_attr, found)

    def assert_attr_visible(self, expect_attr, samdb=None):
        """Check the confidential attribute's visibility for several requests.

        Args:
            expect_attr: True if the attribute should be returned when it is
                asked for (explicitly, via "*", or by default).
            samdb: connection to search with; ``self.ldb_user`` when None.
        """
        connection = self.ldb_user if samdb is None else samdb

        # the attribute is / isn't returned for each way of requesting it
        for requested in (None, ["*"], [self.conf_attr]):
            self.assert_attr_returned(expect_attr, connection, attrs=requested)

        # asking for a different attribute must never return the conf_attr
        self.assert_attr_returned(expect_attr=False,
                                  samdb=connection,
                                  attrs=['name'])

    def assert_attr_visible_to_admin(self):
        """Check that the admin connection sees the confidential attribute.

        Runs the positive searches, the negative searches and the visibility
        check against ``self.ldb_admin`` with full rights.
        """
        admin = self.ldb_admin
        self.assert_conf_attr_searches(has_rights_to="all", samdb=admin)
        self.assert_negative_searches(has_rights_to="all", samdb=admin)
        self.assert_attr_visible(expect_attr=True, samdb=admin)
Exemple #39
0
class BaseSortTests(samba.tests.TestCase):
    """Fixture for sort tests: creates users and the expected sort orders.

    ``setUp`` adds an OU, fills it with ``opts.elements`` generated users and
    precomputes, per attribute, the forward and reverse orderings stored in
    ``self.expected_results`` and ``self.expected_results_binary``.
    """

    # When True, every generated value has its non-word characters replaced
    # by 'X' and the extra "tricky" attributes are not added.
    avoid_tricky_sort = False
    maxDiff = 2000

    def create_user(self, i, n, prefix='sorttest', suffix='', attrs=None,
                    tricky=False):
        """Create test user number ``i`` of ``n`` and add it to the database.

        Args:
            i: index of this user; most attribute values are derived from it.
            n: total number of users being created.
            prefix: text placed before the index in the cn.
            suffix: text placed after the index in the cn.
            attrs: optional dict of extra or overriding attributes.
            tricky: accepted but not used by this method.

        Returns:
            The dict describing the user, including its 'dn'.
        """
        name = "%s%d%s" % (prefix, i, suffix)
        user = {
            'cn': name,
            "objectclass": "user",
            'givenName': "abcdefghijklmnopqrstuvwxyz"[i % 26],
            "roomNumber": "%sb\x00c" % (n - i),
            "carLicense": "后来经",
            "employeeNumber": "%s%sx" % (abs(i * (99 - i)), '\n' * (i & 255)),
            "accountExpires": "%s" % (10 ** 9 + 1000000 * i),
            "msTSExpireDate4": "19%02d0101010000.0Z" % (i % 100),
            "flags": str(i * (n - i)),
            "serialNumber": "abc %s%s%s" % ('AaBb |-/'[i & 7],
                                            ' 3z}'[i & 3],
                                            '"@'[i & 1],),
            "comment": "Favourite colour is %d" % (n % (i + 1)),
        }

        if self.avoid_tricky_sort:
            # We are not even going to try passing tests that assume
            # some kind of Unicode awareness.
            # (Only existing keys are reassigned, so the dict does not change
            # size while it is being iterated.)
            for k, v in user.items():
                user[k] = re.sub(r'[^\w,.]', 'X', v)
        else:
            # Add some even trickier ones!
            fiendish_index = i % len(FIENDISH_TESTS)
            user.update({
                # Sort doesn't look past a NUL byte.
                "photo": "\x00%d" % (n - i),
                "audio": "%sn octet string %s%s ♫♬\x00lalala" % ('Aa'[i & 1],
                                                                 chr(i & 255),
                                                                 i),
                "displayNamePrintable": "%d\x00%c" % (i, i & 255),
                "adminDisplayName": "%d\x00b" % (n-i),
                "title": "%d%sb" % (n - i, '\x00' * i),

                # Names that vary only in case. Windows returns
                # equivalent addresses in the order they were put
                # in ('a st', 'A st',...). We don't check that.
                "street": "%s st" % (chr(65 | (i & 14) | ((i & 1) * 32))),

                "streetAddress": FIENDISH_TESTS[fiendish_index],
                "postalAddress": FIENDISH_TESTS[-fiendish_index],
            })

        if attrs is not None:
            user.update(attrs)

        user['dn'] = "cn=%s,%s" % (user['cn'], self.ou)

        self.users.append(user)
        self.ldb.add(user)
        return user

    def setUp(self):
        """Connect, create the OU and users, and compute expected orderings."""
        # Imported here so the class does not depend on a module-level import.
        from functools import cmp_to_key

        super(BaseSortTests, self).setUp()
        self.ldb = SamDB(host, credentials=creds,
                         session_info=system_session(lp), lp=lp)

        self.base_dn = self.ldb.domain_dn()
        self.ou = "ou=sort,%s" % self.base_dn
        # Disabled clean-up of a left-over OU; flip the condition to enable it.
        if False:
            try:
                self.ldb.delete(self.ou, ['tree_delete:1'])
            except ldb.LdbError as e:
                print("tried deleting %s, got error %s" % (self.ou, e))

        self.ldb.add({
            "dn": self.ou,
            "objectclass": "organizationalUnit"})
        self.users = []
        n = opts.elements
        for i in range(n):
            self.create_user(i, n)

        # Every generated attribute except the structural ones is sorted on.
        attrs = set(self.users[0].keys()) - set([
            'objectclass', 'dn'])
        self.binary_sorted_keys = attrs.intersection(['audio',
                                                      'photo',
                                                      "msTSExpireDate4",
                                                      'serialNumber',
                                                      "displayNamePrintable"])

        self.numeric_sorted_keys = attrs.intersection(['flags',
                                                       'accountExpires'])

        self.timestamp_keys = attrs.intersection(['msTSExpireDate4'])

        self.int64_keys = set(['accountExpires'])

        self.locale_sorted_keys = [x for x in attrs if
                                   x not in (self.binary_sorted_keys |
                                             self.numeric_sorted_keys)]

        self.expected_results = {}
        self.expected_results_binary = {}

        for k in self.locale_sorted_keys:
            # Using key=locale.strxfrm fails on \x00, so wrap the comparison
            # function instead; cmp_to_key replaces the 'cmp' argument that
            # sorted() no longer accepts on Python 3.
            forward = sorted((norm(x[k]) for x in self.users),
                             key=cmp_to_key(locale.strcoll))
            reverse = list(reversed(forward))
            self.expected_results[k] = (forward, reverse)

        for k in self.binary_sorted_keys:
            forward = sorted((x[k] for x in self.users))
            reverse = list(reversed(forward))
            self.expected_results_binary[k] = (forward, reverse)
            self.expected_results[k] = (forward, reverse)

        # Fix up some because Python gets it wrong, using Schwartzian transform
        for k in ('adminDisplayName', 'title', 'streetAddress',
                  'employeeNumber'):
            if k in self.expected_results:
                broken = self.expected_results[k][0]
                # sort on the value with NUL bytes removed, keep the original
                tmp = [(x.replace('\x00', ''), x) for x in broken]
                tmp.sort()
                fixed = [x[1] for x in tmp]
                self.expected_results[k] = (fixed, list(reversed(fixed)))
        for k in ('streetAddress', 'postalAddress'):
            if k in self.expected_results:
                # Count how often each value was used, then emit the values
                # in FIENDISH_TESTS order, repeated by their counts.
                c = {}
                for u in self.users:
                    x = u[k]
                    c[x] = c.get(x, 0) + 1
                fixed = []
                for x in FIENDISH_TESTS:
                    fixed += [norm(x)] * c.get(x, 0)

                rev = list(reversed(fixed))
                self.expected_results[k] = (fixed, rev)
class AuditLogDsdbTests(AuditLogTestBase):
    def setUp(self):
        """Connect to the server over LDAP and (re)create the test user.

        Also switches "dSHeuristics" and "minPwdAge" to the values the tests
        need and registers clean-ups that restore the previous values.
        """
        self.message_type = MSG_DSDB_LOG
        self.event_type = DSDB_EVENT_NAME
        super(AuditLogDsdbTests, self).setUp()

        self.server_ip = os.environ["SERVER_IP"]

        server_name = os.environ["SERVER"]
        self.ldb = SamDB(url="ldap://%s" % server_name,
                         session_info=system_session(),
                         credentials=self.get_credentials(),
                         lp=self.get_loadparm())
        self.server = server_name

        # Remember the base DN
        self.base_dn = self.ldb.domain_dn()

        # Activate the correct "userPassword" behaviour via "dSHeuristics",
        # putting the previous value back on clean-up
        saved_dsheuristics = self.ldb.get_dsheuristics()
        self.ldb.set_dsheuristics("000000001")
        self.addCleanup(self.ldb.set_dsheuristics, saved_dsheuristics)

        # Temporarily set "minPwdAge" to "0", restoring it on clean-up
        saved_min_pwd_age = self.ldb.get_minPwdAge()
        self.ldb.set_minPwdAge("0")
        self.base_dn = self.ldb.domain_dn()
        self.addCleanup(self.ldb.set_minPwdAge, saved_min_pwd_age)

        # (Re)add the test user USER_NAME with password USER_PASS
        user_dn = "cn=" + USER_NAME + ",cn=users," + self.base_dn
        delete_force(self.ldb, user_dn)
        self.ldb.add({
            "dn": user_dn,
            "objectclass": "user",
            "sAMAccountName": USER_NAME,
            "userPassword": USER_PASS
        })

    #
    # Discard the messages from the setup code
    #
    def discardSetupMessages(self, dn):
        """Wait for two messages about ``dn``, then discard what was received."""
        expected_setup_messages = 2
        self.waitForMessages(expected_setup_messages, dn=dn)
        self.discardMessages()

    def tearDown(self):
        """Discard any pending messages, then run the base class tear-down."""
        self.discardMessages()
        super(AuditLogDsdbTests, self).tearDown()

    def haveExpectedTxn(self, expected):
        """Tell whether the stored transaction message has the expected id.

        Args:
            expected: transaction id to look for.

        Returns:
            True if ``self.context["txnMessage"]`` is set and its
            "dsdbTransaction" entry carries ``expected`` as "transactionId";
            False otherwise.
        """
        if self.context["txnMessage"] is None:
            return False
        transaction = self.context["txnMessage"]["dsdbTransaction"]
        return transaction["transactionId"] == expected

    def waitForTransaction(self, expected, connection=None):
        """Wait for the transaction message with id ``expected`` to arrive.

        The connection is held on ``self.connection`` while waiting, to keep
        it alive until all the logging messages have been received.

        Args:
            expected: transaction id to wait for.
            connection: object to keep referenced during the wait.

        Returns:
            The stored transaction message, or "" if more than one second
            passed without it arriving.
        """
        self.connection = connection

        began = time.time()
        outcome = ""
        while True:
            if self.haveExpectedTxn(expected):
                outcome = self.context["txnMessage"]
                break
            self.msg_ctx.loop_once(0.1)
            if time.time() - began > 1:
                # timed out: give up and report the empty result
                break

        self.connection = None
        return outcome

    def test_net_change_password(self):
        # Changes the test user's password with Net.change_password and
        # checks the single "Modify" audit message that results.

        dn = "CN=" + USER_NAME + ",CN=Users," + self.base_dn
        self.discardSetupMessages(dn)

        creds = self.insta_creds(template=self.get_credentials())

        lp = self.get_loadparm()
        net = Net(creds, lp, server=self.server)
        # NOTE(review): this literal looks like a redacted placeholder;
        # confirm it is the intended new password.
        password = "******"

        net.change_password(newpassword=password,
                            username=USER_NAME,
                            oldpassword=USER_PASS)

        messages = self.waitForMessages(1, net, dn=dn)
        print("Received %d messages" % len(messages))
        self.assertEqual(1, len(messages),
                         "Did not receive the expected number of messages")

        audit = messages[0]["dsdbChange"]
        self.assertEqual("Modify", audit["operation"])
        self.assertFalse(audit["performedAsSystem"])
        # Case-insensitive DN comparison. (assertTrue would treat the second
        # argument as the failure message and compare nothing.)
        self.assertEqual(dn.lower(), audit["dn"].lower())
        self.assertRegexpMatches(audit["remoteAddress"], self.remoteAddress)
        session_id = self.get_session()
        self.assertEqual(session_id, audit["sessionId"])
        # We skip the check for self.get_service_description() as this
        # is subject to a race between smbd and the s4 rpc_server code
        # as to which will set the description as it is DCE/RPC over SMB

        self.assertTrue(self.is_guid(audit["transactionId"]))

        # Only the redacted clear-text password replacement is logged.
        attributes = audit["attributes"]
        self.assertEqual(1, len(attributes))
        actions = attributes["clearTextPassword"]["actions"]
        self.assertEqual(1, len(actions))
        self.assertTrue(actions[0]["redacted"])
        self.assertEqual("replace", actions[0]["action"])

    def test_net_set_password(self):
        # Sets the test user's password with Net.set_password and checks the
        # single "Modify" audit message that results.

        dn = "CN=" + USER_NAME + ",CN=Users," + self.base_dn
        self.discardSetupMessages(dn)

        creds = self.insta_creds(template=self.get_credentials())

        lp = self.get_loadparm()
        net = Net(creds, lp, server=self.server)
        # NOTE(review): this literal looks like a redacted placeholder;
        # confirm it is the intended new password.
        password = "******"
        domain = lp.get("workgroup")

        net.set_password(newpassword=password,
                         account_name=USER_NAME,
                         domain_name=domain)
        messages = self.waitForMessages(1, net, dn=dn)
        print("Received %d messages" % len(messages))
        self.assertEqual(1, len(messages),
                         "Did not receive the expected number of messages")
        audit = messages[0]["dsdbChange"]
        self.assertEqual("Modify", audit["operation"])
        self.assertFalse(audit["performedAsSystem"])
        self.assertEqual(dn, audit["dn"])
        self.assertRegexpMatches(audit["remoteAddress"], self.remoteAddress)
        session_id = self.get_session()
        self.assertEqual(session_id, audit["sessionId"])
        # We skip the check for self.get_service_description() as this
        # is subject to a race between smbd and the s4 rpc_server code
        # as to which will set the description as it is DCE/RPC over SMB

        self.assertTrue(self.is_guid(audit["transactionId"]))

        # Only the redacted clear-text password replacement is logged.
        attributes = audit["attributes"]
        self.assertEqual(1, len(attributes))
        actions = attributes["clearTextPassword"]["actions"]
        self.assertEqual(1, len(actions))
        self.assertTrue(actions[0]["redacted"])
        self.assertEqual("replace", actions[0]["action"])

    def test_ldap_change_password(self):
        # Changes userPassword over LDAP with a delete of the old value
        # followed by an add of a new one, then checks the audit message.

        dn = "cn=" + USER_NAME + ",cn=users," + self.base_dn
        self.discardSetupMessages(dn)

        new_password = samba.generate_random_password(32, 32)
        # Delete the password set in setUp (USER_PASS) and add the new one.
        self.ldb.modify_ldif("dn: " + dn + "\n" +
                             "changetype: modify\n" +
                             "delete: userPassword\n" +
                             "userPassword: " + USER_PASS + "\n" +
                             "add: userPassword\n" +
                             "userPassword: " + new_password + "\n")

        messages = self.waitForMessages(1)
        print("Received %d messages" % len(messages))
        self.assertEqual(1, len(messages),
                         "Did not receive the expected number of messages")

        audit = messages[0]["dsdbChange"]
        self.assertEqual("Modify", audit["operation"])
        self.assertFalse(audit["performedAsSystem"])
        self.assertEqual(dn, audit["dn"])
        self.assertRegexpMatches(audit["remoteAddress"], self.remoteAddress)
        self.assertTrue(self.is_guid(audit["sessionId"]))
        session_id = self.get_session()
        self.assertEqual(session_id, audit["sessionId"])
        service_description = self.get_service_description()
        self.assertEqual(service_description, "LDAP")

        # One attribute with two redacted actions: the delete, then the add.
        attributes = audit["attributes"]
        self.assertEqual(1, len(attributes))
        actions = attributes["userPassword"]["actions"]
        self.assertEqual(2, len(actions))
        self.assertTrue(actions[0]["redacted"])
        self.assertEqual("delete", actions[0]["action"])
        self.assertTrue(actions[1]["redacted"])
        self.assertEqual("add", actions[1]["action"])

    def test_ldap_replace_password(self):
        # Replaces userPassword over LDAP and checks the audit message.

        dn = "cn=" + USER_NAME + ",cn=users," + self.base_dn
        self.discardSetupMessages(dn)

        new_password = samba.generate_random_password(32, 32)
        self.ldb.modify_ldif("dn: " + dn + "\n" +
                             "changetype: modify\n" +
                             "replace: userPassword\n" +
                             "userPassword: " + new_password + "\n")

        messages = self.waitForMessages(1, dn=dn)
        print("Received %d messages" % len(messages))
        self.assertEqual(1, len(messages),
                         "Did not receive the expected number of messages")

        audit = messages[0]["dsdbChange"]
        self.assertEqual("Modify", audit["operation"])
        self.assertFalse(audit["performedAsSystem"])
        # Case-insensitive DN comparison. (assertTrue would treat the second
        # argument as the failure message and compare nothing.)
        self.assertEqual(dn.lower(), audit["dn"].lower())
        self.assertRegexpMatches(audit["remoteAddress"], self.remoteAddress)
        self.assertTrue(self.is_guid(audit["sessionId"]))
        session_id = self.get_session()
        self.assertEqual(session_id, audit["sessionId"])
        service_description = self.get_service_description()
        self.assertEqual(service_description, "LDAP")
        self.assertTrue(self.is_guid(audit["transactionId"]))

        # One attribute with a single redacted "replace" action.
        attributes = audit["attributes"]
        self.assertEqual(1, len(attributes))
        actions = attributes["userPassword"]["actions"]
        self.assertEqual(1, len(actions))
        self.assertTrue(actions[0]["redacted"])
        self.assertEqual("replace", actions[0]["action"])

    def test_ldap_add_user(self):

        # The setup code adds a user, so we check for the dsdb events
        # generated by it.
        dn = "cn=" + USER_NAME + ",cn=users," + self.base_dn
        messages = self.waitForMessages(2, dn=dn)
        print("Received %d messages" % len(messages))
        self.assertEqual(2, len(messages),
                         "Did not receive the expected number of messages")

        # The second of the two messages is the one checked as the "Add".
        audit = messages[1]["dsdbChange"]
        self.assertEqual("Add", audit["operation"])
        self.assertFalse(audit["performedAsSystem"])
        self.assertEqual(dn, audit["dn"])
        self.assertRegexpMatches(audit["remoteAddress"], self.remoteAddress)
        session_id = self.get_session()
        self.assertEqual(session_id, audit["sessionId"])
        service_description = self.get_service_description()
        self.assertEqual(service_description, "LDAP")
        self.assertTrue(self.is_guid(audit["sessionId"]))
        self.assertTrue(self.is_guid(audit["transactionId"]))

        # The three attributes passed to ldb.add() in setUp are logged.
        attributes = audit["attributes"]
        self.assertEqual(3, len(attributes))

        actions = attributes["objectclass"]["actions"]
        self.assertEqual(1, len(actions))
        self.assertEqual("add", actions[0]["action"])
        self.assertEqual(1, len(actions[0]["values"]))
        self.assertEqual("user", actions[0]["values"][0]["value"])

        actions = attributes["sAMAccountName"]["actions"]
        self.assertEqual(1, len(actions))
        self.assertEqual("add", actions[0]["action"])
        self.assertEqual(1, len(actions[0]["values"]))
        self.assertEqual(USER_NAME, actions[0]["values"][0]["value"])

        # The password value itself must be redacted.
        actions = attributes["userPassword"]["actions"]
        self.assertEqual(1, len(actions))
        self.assertEqual("add", actions[0]["action"])
        self.assertTrue(actions[0]["redacted"])

    def test_samdb_delete_user(self):
        # Deletes the test user and checks both the "Delete" audit message
        # and the matching transaction commit message.

        dn = "cn=" + USER_NAME + ",cn=users," + self.base_dn
        self.discardSetupMessages(dn)

        self.ldb.deleteuser(USER_NAME)

        messages = self.waitForMessages(1, dn=dn)
        print("Received %d messages" % len(messages))
        self.assertEqual(1, len(messages),
                         "Did not receive the expected number of messages")

        audit = messages[0]["dsdbChange"]
        self.assertEqual("Delete", audit["operation"])
        self.assertFalse(audit["performedAsSystem"])
        # Case-insensitive DN comparison. (assertTrue would treat the second
        # argument as the failure message and compare nothing.)
        self.assertEqual(dn.lower(), audit["dn"].lower())
        self.assertRegexpMatches(audit["remoteAddress"], self.remoteAddress)
        self.assertTrue(self.is_guid(audit["sessionId"]))
        self.assertEqual(0, audit["statusCode"])
        self.assertEqual("Success", audit["status"])
        session_id = self.get_session()
        self.assertEqual(session_id, audit["sessionId"])
        service_description = self.get_service_description()
        self.assertEqual(service_description, "LDAP")

        # The transaction the change belonged to must have been committed.
        transactionId = audit["transactionId"]
        message = self.waitForTransaction(transactionId)
        audit = message["dsdbTransaction"]
        self.assertEqual("commit", audit["action"])
        self.assertTrue(self.is_guid(audit["transactionId"]))
        self.assertTrue(audit["duration"] > 0)

    def test_samdb_delete_non_existent_dn(self):
        # Tries to delete an entry that does not exist and checks the failed
        # "Delete" audit message and the transaction rollback message.

        DOES_NOT_EXIST = "doesNotExist"
        dn = "cn=" + USER_NAME + ",cn=users," + self.base_dn
        self.discardSetupMessages(dn)

        dn = "cn=" + DOES_NOT_EXIST + ",cn=users," + self.base_dn
        # The delete must raise. assertRaises is used because a self.fail()
        # inside a try block guarded by "except Exception" would itself be
        # caught and silently ignored.
        self.assertRaises(Exception, self.ldb.delete, dn)

        messages = self.waitForMessages(1)
        print("Received %d messages" % len(messages))
        self.assertEqual(1, len(messages),
                         "Did not receive the expected number of messages")

        audit = messages[0]["dsdbChange"]
        self.assertEqual("Delete", audit["operation"])
        self.assertFalse(audit["performedAsSystem"])
        # Case-insensitive DN comparison. (assertTrue would treat the second
        # argument as the failure message and compare nothing.)
        self.assertEqual(dn.lower(), audit["dn"].lower())
        self.assertRegexpMatches(audit["remoteAddress"], self.remoteAddress)
        self.assertEqual(ERR_NO_SUCH_OBJECT, audit["statusCode"])
        self.assertEqual("No such object", audit["status"])
        self.assertTrue(self.is_guid(audit["sessionId"]))
        session_id = self.get_session()
        self.assertEqual(session_id, audit["sessionId"])
        service_description = self.get_service_description()
        self.assertEqual(service_description, "LDAP")

        # The transaction the failed change belonged to must be rolled back.
        transactionId = audit["transactionId"]
        message = self.waitForTransaction(transactionId)
        audit = message["dsdbTransaction"]
        self.assertEqual("rollback", audit["action"])
        self.assertTrue(self.is_guid(audit["transactionId"]))
        self.assertTrue(audit["duration"] > 0)

    def test_create_and_delete_secret_over_lsa(self):
        # Creates a secret over the LSA pipe, checks the "Add" audit message,
        # then deletes the secret and checks the "Delete" audit message.

        dn = "cn=Test Secret,CN=System," + self.base_dn
        self.discardSetupMessages(dn)

        creds = self.insta_creds(template=self.get_credentials())
        lsa_conn = lsa.lsarpc("ncacn_np:%s" % self.server, self.get_loadparm(),
                              creds)
        lsa_handle = lsa_conn.OpenPolicy2(
            system_name="\\",
            attr=lsa.ObjectAttribute(),
            access_mask=security.SEC_FLAG_MAXIMUM_ALLOWED)
        secret_name = lsa.String()
        secret_name.string = "G$Test"
        lsa_conn.CreateSecret(handle=lsa_handle,
                              name=secret_name,
                              access_mask=security.SEC_FLAG_MAXIMUM_ALLOWED)

        messages = self.waitForMessages(1, dn=dn)
        print("Received %d messages" % len(messages))
        self.assertEqual(1, len(messages),
                         "Did not receive the expected number of messages")

        audit = messages[0]["dsdbChange"]
        self.assertEqual("Add", audit["operation"])
        self.assertTrue(audit["performedAsSystem"])
        # Case-insensitive DN comparison. (assertTrue would treat the second
        # argument as the failure message and compare nothing.)
        self.assertEqual(dn.lower(), audit["dn"].lower())
        self.assertRegexpMatches(audit["remoteAddress"], self.remoteAddress)
        self.assertTrue(self.is_guid(audit["sessionId"]))
        session_id = self.get_session()
        self.assertEqual(session_id, audit["sessionId"])

        # We skip the check for self.get_service_description() as this
        # is subject to a race between smbd and the s4 rpc_server code
        # as to which will set the description as it is DCE/RPC over SMB

        attributes = audit["attributes"]
        self.assertEqual(2, len(attributes))

        object_class = attributes["objectClass"]
        self.assertEqual(1, len(object_class["actions"]))
        action = object_class["actions"][0]
        self.assertEqual("add", action["action"])
        values = action["values"]
        self.assertEqual(1, len(values))
        self.assertEqual("secret", values[0]["value"])

        cn = attributes["cn"]
        self.assertEqual(1, len(cn["actions"]))
        action = cn["actions"][0]
        self.assertEqual("add", action["action"])
        values = action["values"]
        self.assertEqual(1, len(values))
        self.assertEqual("Test Secret", values[0]["value"])

        #
        # Now delete the secret.
        self.discardMessages()
        h = lsa_conn.OpenSecret(handle=lsa_handle,
                                name=secret_name,
                                access_mask=security.SEC_FLAG_MAXIMUM_ALLOWED)

        lsa_conn.DeleteObject(h)
        messages = self.waitForMessages(1, dn=dn)
        print("Received %d messages" % len(messages))
        self.assertEqual(1, len(messages),
                         "Did not receive the expected number of messages")

        audit = messages[0]["dsdbChange"]
        self.assertEqual("Delete", audit["operation"])
        self.assertTrue(audit["performedAsSystem"])
        self.assertEqual(dn.lower(), audit["dn"].lower())
        self.assertRegexpMatches(audit["remoteAddress"], self.remoteAddress)
        self.assertTrue(self.is_guid(audit["sessionId"]))
        session_id = self.get_session()
        self.assertEqual(session_id, audit["sessionId"])

        # We skip the check for self.get_service_description() as this
        # is subject to a race between smbd and the s4 rpc_server code
        # as to which will set the description as it is DCE/RPC over SMB

    def test_modify(self):
        """Check the dsdbChange audit records produced by attribute changes.

        A series of LDIF changes is applied to the ``carLicense`` attribute
        of the test user.  After each change exactly one audit message is
        expected for the user's DN, reporting a single action on
        ``carLicense`` with the values that were sent.  The first message
        is additionally checked for operation, DN, remote address, session
        id and service description.
        """
        user_dn = "cn=" + USER_NAME + ",cn=users," + self.base_dn
        self.discardSetupMessages(user_dn)

        def next_change():
            # Wait for the audit message for this DN, insist there is
            # exactly one, and hand back its dsdbChange payload.
            received = self.waitForMessages(1, dn=user_dn)
            print("Received %d messages" % len(received))
            self.assertEquals(1, len(received),
                              "Did not receive the expected number of messages")
            return received[0]["dsdbChange"]

        def check_car_license(change, expected_action, expected_values):
            # The change must touch only carLicense, with one action that
            # carries exactly the expected values, in order.
            attributes = change["attributes"]
            self.assertEquals(1, len(attributes))
            actions = attributes["carLicense"]["actions"]
            self.assertEquals(1, len(actions))
            self.assertEquals(expected_action, actions[0]["action"])
            values = actions[0]["values"]
            self.assertEquals(len(expected_values), len(values))
            for position, expected in enumerate(expected_values):
                self.assertEquals(expected, values[position]["value"])

        def build_ldif(changetype, action, licenses):
            # Assemble the LDIF record: dn, changetype, the action
            # directive for carLicense, then one line per value.
            record = "dn: " + user_dn + "\n"
            record += "changetype: " + changetype + "\n"
            record += action + ": carLicense\n"
            for licence in licenses:
                record += "carLicense: " + licence + "\n"
            return record

        #
        # Add an attribute value; this first message also gets the full
        # set of header checks.
        #
        self.ldb.modify_ldif(build_ldif("modify", "add", ["license-01"]))

        audit = next_change()
        self.assertEquals("Modify", audit["operation"])
        self.assertFalse(audit["performedAsSystem"])
        self.assertEquals(user_dn, audit["dn"])
        self.assertRegexpMatches(audit["remoteAddress"], self.remoteAddress)
        self.assertTrue(self.is_guid(audit["sessionId"]))
        self.assertEquals(self.get_session(), audit["sessionId"])
        self.assertEquals(self.get_service_description(), "LDAP")
        check_car_license(audit, "add", ["license-01"])

        # Remaining steps: (changetype, action, values).
        # NOTE(review): the last two records say "changetype: delete" while
        # carrying a modify-style "delete:" / "replace:" directive; the
        # assertions expect a delete/replace action on carLicense.  Confirm
        # that this changetype is intentional.
        follow_ups = (
            # Add another value to the attribute
            ("modify", "add", ["license-02"]),
            # Add another two values to the attribute
            ("modify", "add", ["license-03", "license-04"]),
            # Delete two values from the attribute
            ("delete", "delete", ["license-03", "license-04"]),
            # Replace the attribute with two values
            ("delete", "replace", ["license-05", "license-06"]),
        )
        for changetype, action, licenses in follow_ups:
            self.discardMessages()
            self.ldb.modify_ldif(build_ldif(changetype, action, licenses))
            check_car_license(next_change(), action, licenses)
Exemple #41
0
class BaseSortTests(samba.tests.TestCase):
    """Fixture for sort tests: populates an OU and precomputes sort orders.

    ``setUp`` creates ``opts.elements`` users under ``ou=sort`` and fills
    ``self.expected_results`` (and ``self.expected_results_binary``) with
    ``(forward, reverse)`` lists of the values each attribute is expected
    to sort into.
    """
    # When true, attribute values are reduced to word characters, ',' and
    # '.', and the extra NUL/Unicode-laden attributes are not added.
    avoid_tricky_sort = False
    maxDiff = 2000

    def create_user(self, i, n, prefix='sorttest', suffix='', attrs=None,
                    tricky=False):
        """Build user number ``i`` of ``n``, add it to the directory.

        :param i: index of this user; most attribute values derive from it
        :param n: total number of users being created
        :param prefix: text placed before ``i`` in the cn
        :param suffix: text placed after ``i`` in the cn
        :param attrs: optional dict of attributes overriding the generated
            ones
        :param tricky: accepted but not used by this method
        :return: the attribute dict that was added (including its ``dn``)

        The dict is also appended to ``self.users``.
        """
        name = "%s%d%s" % (prefix, i, suffix)
        user = {
            'cn': name,
            "objectclass": "user",
            'givenName': "abcdefghijklmnopqrstuvwxyz"[i % 26],
            "roomNumber": "%sb\x00c" % (n - i),
            "carLicense": "后来经",
            "employeeNumber": "%s%sx" % (abs(i * (99 - i)), '\n' * (i & 255)),
            "accountExpires": "%s" % (10 ** 9 + 1000000 * i),
            "msTSExpireDate4": "19%02d0101010000.0Z" % (i % 100),
            "flags": str(i * (n - i)),
            "serialNumber": "abc %s%s%s" % ('AaBb |-/'[i & 7],
                                            ' 3z}'[i & 3],
                                            '"@'[i & 1],),
            "comment": "Favourite colour is %d" % (n % (i + 1)),
        }

        if self.avoid_tricky_sort:
            # We are not even going to try passing tests that assume
            # some kind of Unicode awareness.
            user = {k: re.sub(r'[^\w,.]', 'X', v) for k, v in user.items()}
        else:
            # Add some even trickier ones!
            fiendish_index = i % len(FIENDISH_TESTS)
            user.update({
                # Sort doesn't look past a NUL byte.
                "photo": "\x00%d" % (n - i),
                "audio": "%sn octet string %s%s ♫♬\x00lalala" % ('Aa'[i & 1],
                                                                 chr(i & 255),
                                                                 i),
                "displayNamePrintable": "%d\x00%c" % (i, i & 255),
                "adminDisplayName": "%d\x00b" % (n - i),
                "title": "%d%sb" % (n - i, '\x00' * i),

                # Names that vary only in case. Windows returns
                # equivalent addresses in the order they were put
                # in ('a st', 'A st',...). We don't check that.
                "street": "%s st" % (chr(65 | (i & 14) | ((i & 1) * 32))),

                "streetAddress": FIENDISH_TESTS[fiendish_index],
                "postalAddress": FIENDISH_TESTS[-fiendish_index],
            })

        if attrs is not None:
            user.update(attrs)

        user['dn'] = "cn=%s,%s" % (user['cn'], self.ou)

        self.users.append(user)
        self.ldb.add(user)
        return user

    def setUp(self):
        """Create the OU and users, then compute the expected sort orders."""
        # Local import: cmp_to_key replaces the Python 2-only ``cmp=``
        # argument of sorted() and works on both Python 2.7 and Python 3.
        import functools

        super(BaseSortTests, self).setUp()
        self.ldb = SamDB(host, credentials=creds,
                         session_info=system_session(lp), lp=lp)

        self.base_dn = self.ldb.domain_dn()
        self.ou = "ou=sort,%s" % self.base_dn
        if False:
            # Disabled by a constant-false guard: would tree-delete a
            # pre-existing OU and only print any resulting LdbError.
            try:
                self.ldb.delete(self.ou, ['tree_delete:1'])
            except ldb.LdbError as e:
                print("tried deleting %s, got error %s" % (self.ou, e))

        self.ldb.add({
            "dn": self.ou,
            "objectclass": "organizationalUnit"})
        self.users = []
        n = opts.elements
        for i in range(n):
            self.create_user(i, n)

        # Every generated attribute except the two structural ones.
        attrs = set(self.users[0]) - {'objectclass', 'dn'}
        self.binary_sorted_keys = attrs.intersection(['audio',
                                                      'photo',
                                                      "msTSExpireDate4",
                                                      'serialNumber',
                                                      "displayNamePrintable"])

        self.numeric_sorted_keys = attrs.intersection(['flags',
                                                       'accountExpires'])

        self.timestamp_keys = attrs.intersection(['msTSExpireDate4'])

        self.int64_keys = {'accountExpires'}

        self.locale_sorted_keys = [x for x in attrs if
                                   x not in (self.binary_sorted_keys |
                                             self.numeric_sorted_keys)]

        self.expected_results = {}
        self.expected_results_binary = {}

        for k in self.locale_sorted_keys:
            # Using key=locale.strxfrm fails on \x00
            forward = sorted((norm(x[k]) for x in self.users),
                             key=functools.cmp_to_key(locale.strcoll))
            reverse = list(reversed(forward))
            self.expected_results[k] = (forward, reverse)

        for k in self.binary_sorted_keys:
            forward = sorted(x[k] for x in self.users)
            reverse = list(reversed(forward))
            self.expected_results_binary[k] = (forward, reverse)
            self.expected_results[k] = (forward, reverse)

        # Fix up some because Python gets it wrong, using Schwartzian
        # transform: re-sort on the value with NUL bytes stripped.
        for k in ('adminDisplayName', 'title', 'streetAddress',
                  'employeeNumber'):
            if k in self.expected_results:
                broken = self.expected_results[k][0]
                tmp = [(x.replace('\x00', ''), x) for x in broken]
                tmp.sort()
                fixed = [x[1] for x in tmp]
                self.expected_results[k] = (fixed, list(reversed(fixed)))

        # The address attributes are expected in FIENDISH_TESTS order,
        # each entry repeated as often as users carry it.
        for k in ('streetAddress', 'postalAddress'):
            if k in self.expected_results:
                c = {}
                for u in self.users:
                    x = u[k]
                    c[x] = c.get(x, 0) + 1
                fixed = []
                for x in FIENDISH_TESTS:
                    # .get(): an entry no user carries contributes nothing
                    # instead of raising KeyError.
                    fixed += [norm(x)] * c.get(x, 0)

                rev = list(reversed(fixed))
                self.expected_results[k] = (fixed, rev)
class LATests(samba.tests.TestCase):

    def setUp(self):
        """Connect to the directory and create the ``OU=la`` container.

        With ``opts.delete_in_setup`` set, an existing OU is tree-deleted
        first; an LdbError from that delete is printed, not raised.
        """
        super(LATests, self).setUp()
        session = system_session(lp)
        self.samdb = SamDB(host, credentials=creds,
                           session_info=session, lp=lp)

        self.base_dn = self.samdb.domain_dn()
        self.ou = "OU=la,%s" % self.base_dn

        if opts.delete_in_setup:
            try:
                self.samdb.delete(self.ou, ['tree_delete:1'])
            except ldb.LdbError as err:
                print("tried deleting %s, got error %s" % (self.ou, err))

        ou_record = {'objectclass': 'organizationalUnit', 'dn': self.ou}
        self.samdb.add(ou_record)

    def tearDown(self):
        """Tree-delete the test OU unless ``opts.no_cleanup`` is set."""
        super(LATests, self).tearDown()
        if opts.no_cleanup:
            return
        self.samdb.delete(self.ou, ['tree_delete:1'])

    def add_object(self, cn, objectclass, more_attrs={}):
        """Add one object under the test OU and return its DN.

        :param cn: common name; the DN is ``CN=<cn>,<self.ou>``
        :param objectclass: value for the objectclass attribute
        :param more_attrs: extra attributes merged over the defaults
            (only read here, never modified)
        :return: the DN string computed from ``cn``
        """
        target_dn = "CN=%s,%s" % (cn, self.ou)
        record = {'cn': cn, 'objectclass': objectclass, 'dn': target_dn}
        record.update(more_attrs)
        self.samdb.add(record)
        return target_dn

    def add_objects(self, n, objectclass, prefix=None, more_attrs={}):
        """Add ``n`` objects named ``<prefix>1`` .. ``<prefix>n``.

        :param n: number of objects to create
        :param objectclass: objectclass of every object
        :param prefix: name prefix; defaults to ``objectclass``
        :param more_attrs: extra attributes given to every object
        :return: list of the created DNs, in creation order
        """
        if prefix is None:
            prefix = objectclass
        return [self.add_object("%s%d" % (prefix, number),
                                objectclass,
                                more_attrs=more_attrs)
                for number in range(1, n + 1)]

    def add_linked_attribute(self, src, dest, attr='member',
                             controls=None):
        """Send a modify on ``src`` adding ``dest`` to ``attr``.

        :param src: DN of the object to modify
        :param dest: value (or list of values) for the message element
        :param attr: attribute name, ``member`` by default
        :param controls: optional controls passed to the modify
        """
        change = ldb.Message()
        change.dn = ldb.Dn(self.samdb, src)
        element = ldb.MessageElement(dest, ldb.FLAG_MOD_ADD, attr)
        change[attr] = element
        self.samdb.modify(change, controls=controls)

    def remove_linked_attribute(self, src, dest, attr='member',
                                controls=None):
        """Send a modify on ``src`` deleting ``dest`` from ``attr``.

        :param src: DN of the object to modify
        :param dest: value (or list of values) for the message element
        :param attr: attribute name, ``member`` by default
        :param controls: optional controls passed to the modify
        """
        change = ldb.Message()
        change.dn = ldb.Dn(self.samdb, src)
        element = ldb.MessageElement(dest, ldb.FLAG_MOD_DELETE, attr)
        change[attr] = element
        self.samdb.modify(change, controls=controls)

    def replace_linked_attribute(self, src, dest, attr='member',
                                 controls=None):
        """Send a modify on ``src`` replacing ``attr`` with ``dest``.

        :param src: DN of the object to modify
        :param dest: value (or list of values) for the message element
        :param attr: attribute name, ``member`` by default
        :param controls: optional controls passed to the modify
        """
        change = ldb.Message()
        change.dn = ldb.Dn(self.samdb, src)
        element = ldb.MessageElement(dest, ldb.FLAG_MOD_REPLACE, attr)
        change[attr] = element
        self.samdb.modify(change, controls=controls)

    def attr_search(self, obj, attr, scope=ldb.SCOPE_BASE, **controls):
        """Search ``obj`` for a single attribute, with optional controls.

        :param obj: search base
        :param attr: the one attribute to request
        :param scope: search scope, base scope by default
        :param controls: each ``name=value`` keyword becomes the control
            string ``name:<int(value)>``; ``reveal_internals`` is dropped
            when ``opts.no_reveal_internals`` is set
        :return: the search result
        """
        if opts.no_reveal_internals:
            controls.pop('reveal_internals', None)

        control_strings = []
        for name, value in controls.items():
            control_strings.append('%s:%d' % (name, int(value)))

        return self.samdb.search(obj,
                                 scope=scope,
                                 attrs=[attr],
                                 controls=control_strings)

    def assert_links(self, obj, expected, attr, msg='', **kwargs):
        """Assert that ``attr`` on ``obj`` holds exactly ``expected``.

        :param obj: search base handed to ``attr_search``
        :param expected: expected values; compared after sorting, so
            order is irrelevant.  An empty collection means the attribute
            must be absent from the first result.
        :param attr: attribute to inspect
        :param msg: text printed ahead of the expected/received dump when
            the values differ
        :param kwargs: forwarded to ``attr_search`` as controls
        """
        found = self.attr_search(obj, attr, **kwargs)

        if len(expected) != 0:
            try:
                per_entry = [entry[attr] for entry in found]
                results = list(per_entry[0])
            except KeyError:
                self.fail("missing attr '%s' on %s" % (attr, obj))

            wanted = sorted(expected)
            received = sorted(results)

            if wanted != received:
                # Dump both sides before the assertion reports the failure.
                print(msg)
                print("expected %s" % wanted)
                print("received %s" % received)

            self.assertEqual(received, wanted)
            return

        # Nothing expected: the attribute must not be present at all.
        if attr in found[0]:
            self.fail("found attr '%s' in %s" % (attr, found[0]))

    def assert_back_links(self, obj, expected, attr='memberOf', **kwargs):
        """Assert the back-link attribute (``memberOf`` by default)."""
        mismatch_text = 'back links do not match'
        self.assert_links(obj, expected, attr=attr, msg=mismatch_text,
                          **kwargs)

    def assert_forward_links(self, obj, expected, attr='member', **kwargs):
        """Assert the forward-link attribute (``member`` by default)."""
        mismatch_text = 'forward links do not match'
        self.assert_links(obj, expected, attr=attr, msg=mismatch_text,
                          **kwargs)

    def get_object_guid(self, dn):
        """Return the objectGUID of ``dn`` as a string."""
        found = self.samdb.search(dn,
                                  scope=ldb.SCOPE_BASE,
                                  attrs=['objectGUID'])
        raw_guid = found[0]['objectGUID'][0]
        guid = misc.GUID(raw_guid)
        return str(guid)

    def assertRaisesLdbError(self, errcode, msg, f, *args, **kwargs):
        """Assert that calling ``f`` raises an LdbError numbered ``errcode``.

        :param errcode: expected error number (an ``ldb.ERR_*`` value)
        :param msg: caller's description, used as the prefix of any
            failure message
        :param f: callable to invoke
        :param args: positional arguments for ``f``
        :param kwargs: keyword arguments for ``f``

        The test fails if ``f`` returns normally or raises an LdbError with
        a different number.  Exceptions of other types propagate unchanged.
        """
        def error_name(code):
            # Map an error number back to its ldb.ERR_* constant name
            # (None when there is no such constant); only needed when a
            # failure is being reported.
            names = {v: k for k, v in vars(ldb).items()
                     if k.startswith('ERR_') and isinstance(v, int)}
            return names.get(code)

        try:
            f(*args, **kwargs)
        except ldb.LdbError as e:
            # Unpack the error text into its own name so the caller's
            # ``msg`` is not overwritten before it is reported.
            (num, estr) = e.args
            if num != errcode:
                self.fail("%s, expected "
                          "LdbError %s, (%d) "
                          "got %s (%d): %s" % (msg,
                                               error_name(errcode), errcode,
                                               error_name(num), num,
                                               estr))
        else:
            self.fail("%s, expected "
                      "LdbError %s, (%d) "
                      "but we got success" % (msg,
                                              error_name(errcode), errcode))

    def _test_la_backlinks(self, reveal=False):
        """Link users into groups and check the users' back links.

        :param reveal: when true, object names get a ``_reveal`` suffix
            and the checks pass ``reveal_internals=0``
        """
        tag = 'backlinks_reveal' if reveal else 'backlinks'
        search_opts = {'reveal_internals': 0} if reveal else {}

        u1, u2 = self.add_objects(2, 'user', 'u_%s' % tag)
        g1, g2 = self.add_objects(2, 'group', 'g_%s' % tag)

        for group, member in ((g1, u1), (g2, u1), (g2, u2)):
            self.add_linked_attribute(group, member)

        self.assert_back_links(u1, [g1, g2], **search_opts)
        self.assert_back_links(u2, [g2], **search_opts)

    def test_la_backlinks(self):
        """Back-link check without the reveal_internals control."""
        self._test_la_backlinks(reveal=False)

    def test_la_backlinks_reveal(self):
        """Back-link check with reveal_internals, unless disabled by option."""
        if not opts.no_reveal_internals:
            self._test_la_backlinks(reveal=True)
            return
        print('skipping because --no-reveal-internals')

    def _test_la_backlinks_delete_group(self, reveal=False):
        """Delete a group and check its members' back links afterwards.

        :param reveal: when true, object names get a ``_reveal`` suffix
            and the checks pass ``reveal_internals=0``
        """
        tag = 'del_group_reveal' if reveal else 'del_group'
        search_opts = {'reveal_internals': 0} if reveal else {}

        u1, u2 = self.add_objects(2, 'user', 'u_' + tag)
        g1, g2 = self.add_objects(2, 'group', 'g_' + tag)

        for group, member in ((g1, u1), (g2, u1), (g2, u2)):
            self.add_linked_attribute(group, member)

        self.samdb.delete(g2, ['tree_delete:1'])

        # Only the surviving group may still appear as a back link.
        self.assert_back_links(u1, [g1], **search_opts)
        self.assert_back_links(u2, set(), **search_opts)

    def test_la_backlinks_delete_group(self):
        """Group-deletion back-link check without reveal_internals."""
        self._test_la_backlinks_delete_group(reveal=False)

    def test_la_backlinks_delete_group_reveal(self):
        """Group-deletion back-link check with reveal_internals, unless
        disabled by option."""
        if not opts.no_reveal_internals:
            self._test_la_backlinks_delete_group(reveal=True)
            return
        print('skipping because --no-reveal-internals')

    def test_links_all_delete_group(self):
        """After deleting a group, check links in both directions while
        searching with show_deleted and show_recycled."""
        u1, u2 = self.add_objects(2, 'user', 'u_all_del_group')
        g1, g2 = self.add_objects(2, 'group', 'g_all_del_group')
        # Remember the GUID so the group can still be addressed once its
        # DN is gone.
        g2guid = self.get_object_guid(g2)

        for group, member in ((g1, u1), (g2, u1), (g2, u2)):
            self.add_linked_attribute(group, member)

        self.samdb.delete(g2)

        search_opts = {'show_deleted': 1,
                       'show_recycled': 1,
                       'show_deactivated_link': 0}
        self.assert_back_links(u1, [g1], **search_opts)
        self.assert_back_links(u2, set(), **search_opts)
        self.assert_forward_links(g1, [u1], **search_opts)
        self.assert_forward_links('<GUID=%s>' % g2guid, [], **search_opts)

    def test_links_all_delete_group_reveal(self):
        """Same checks as the non-reveal group-deletion test, with
        ``reveal_internals=0`` added to every search."""
        u1, u2 = self.add_objects(2, 'user', 'u_all_del_group_reveal')
        g1, g2 = self.add_objects(2, 'group', 'g_all_del_group_reveal')
        # Remember the GUID so the group can still be addressed once its
        # DN is gone.
        g2guid = self.get_object_guid(g2)

        for group, member in ((g1, u1), (g2, u1), (g2, u2)):
            self.add_linked_attribute(group, member)

        self.samdb.delete(g2)

        search_opts = {'show_deleted': 1,
                       'show_recycled': 1,
                       'show_deactivated_link': 0,
                       'reveal_internals': 0}
        self.assert_back_links(u1, [g1], **search_opts)
        self.assert_back_links(u2, set(), **search_opts)
        self.assert_forward_links(g1, [u1], **search_opts)
        self.assert_forward_links('<GUID=%s>' % g2guid, [], **search_opts)

    def test_la_links_delete_link(self):
        """Adding and removing links changes the group's uSNChanged and
        leaves the expected forward links behind."""
        u1, u2 = self.add_objects(2, 'user', 'u_del_link')
        g1, g2 = self.add_objects(2, 'group', 'g_del_link')

        def read_usn(group_dn):
            # Current uSNChanged of the group, as an integer.
            found = self.samdb.search(group_dn, scope=ldb.SCOPE_BASE,
                                      attrs=['uSNChanged'])
            return int(found[0]['uSNChanged'][0])

        # Adding a link must change the group's USN.
        old_usn1 = read_usn(g1)
        self.add_linked_attribute(g1, u1)
        new_usn1 = read_usn(g1)
        self.assertNotEqual(old_usn1, new_usn1, "USN should have incremented")

        self.add_linked_attribute(g2, u1)
        self.add_linked_attribute(g2, u2)

        # Removing a link must change it as well.
        old_usn2 = read_usn(g2)
        self.remove_linked_attribute(g2, u1)
        new_usn2 = read_usn(g2)
        self.assertNotEqual(old_usn2, new_usn2, "USN should have incremented")

        self.assert_forward_links(g1, [u1])
        self.assert_forward_links(g2, [u2])

        # (operation, group, links, forward links expected afterwards)
        steps = (
            (self.add_linked_attribute, g2, u1, [u1, u2]),
            (self.remove_linked_attribute, g2, u2, [u1]),
            (self.remove_linked_attribute, g2, u1, []),
            (self.remove_linked_attribute, g1, [], []),
        )
        for operation, group, links, remaining in steps:
            operation(group, links)
            self.assert_forward_links(group, remaining)

        # removing a duplicate link in the same message should fail
        self.add_linked_attribute(g2, [u1, u2])
        with self.assertRaises(ldb.LdbError):
            self.remove_linked_attribute(g2, [u1, u1])

    def _test_la_links_delete_link_reveal(self):
        """After removing one member, a search with the reveal controls
        is expected to list both users on the group."""
        u1, u2 = self.add_objects(2, 'user', 'u_del_link_reveal')
        g1, g2 = self.add_objects(2, 'group', 'g_del_link_reveal')

        for group, member in ((g1, u1), (g2, u1), (g2, u2)):
            self.add_linked_attribute(group, member)

        self.remove_linked_attribute(g2, u1)

        search_opts = {'show_deleted': 1,
                       'show_recycled': 1,
                       'show_deactivated_link': 0,
                       'reveal_internals': 0}
        self.assert_forward_links(g2, [u1, u2], **search_opts)

    def test_la_links_delete_link_reveal(self):
        """Link-removal check with reveal_internals, unless disabled by
        option."""
        if not opts.no_reveal_internals:
            self._test_la_links_delete_link_reveal()
            return
        print('skipping because --no-reveal-internals')

    def test_la_links_delete_user(self):
        """Deleting a user drops it from the groups' forward links while
        the groups' uSNChanged values stay the same."""
        u1, u2 = self.add_objects(2, 'user', 'u_del_user')
        g1, g2 = self.add_objects(2, 'group', 'g_del_user')

        for group, member in ((g1, u1), (g2, u1), (g2, u2)):
            self.add_linked_attribute(group, member)

        def read_usn(group_dn):
            # Current uSNChanged of the group, as an integer.
            found = self.samdb.search(group_dn, scope=ldb.SCOPE_BASE,
                                      attrs=['uSNChanged'])
            return int(found[0]['uSNChanged'][0])

        old_usn1 = read_usn(g1)
        old_usn2 = read_usn(g2)

        self.samdb.delete(u1)

        self.assert_forward_links(g1, [])
        self.assert_forward_links(g2, [u2])

        new_usn1 = read_usn(g1)
        new_usn2 = read_usn(g2)

        # Assert the USN on the alternate object is unchanged
        self.assertEqual(old_usn1, new_usn1)
        self.assertEqual(old_usn2, new_usn2)

    def test_la_links_delete_user_reveal(self):
        """After deleting a user, check the groups' forward links using
        the show_deleted/show_recycled/reveal_internals controls."""
        u1, u2 = self.add_objects(2, 'user', 'u_del_user_reveal')
        g1, g2 = self.add_objects(2, 'group', 'g_del_user_reveal')

        for group, member in ((g1, u1), (g2, u1), (g2, u2)):
            self.add_linked_attribute(group, member)

        self.samdb.delete(u1)

        search_opts = {'show_deleted': 1,
                       'show_recycled': 1,
                       'show_deactivated_link': 0,
                       'reveal_internals': 0}
        self.assert_forward_links(g2, [u2], **search_opts)
        self.assert_forward_links(g1, [], **search_opts)

    def test_multiple_links(self):
        """Add and remove several links per message and check forward and
        back links after every stage."""
        u1, u2, u3, u4 = self.add_objects(4, 'user', 'u_multiple_links')
        g1, g2, g3, g4 = self.add_objects(4, 'group', 'g_multiple_links')

        def check(forward, back):
            # forward: (group, expected members) pairs;
            # back: (user, expected groups) pairs.
            for group, members in forward:
                self.assert_forward_links(group, members)
            for user, memberships in back:
                self.assert_back_links(user, memberships)

        self.add_linked_attribute(g1, [u1, u2, u3, u4])
        self.add_linked_attribute(g2, [u3, u1])
        self.add_linked_attribute(g3, u2)

        # A message repeating a value must be rejected as a whole.
        self.assertRaisesLdbError(ldb.ERR_ENTRY_ALREADY_EXISTS,
                                  "adding duplicate values",
                                  self.add_linked_attribute, g2,
                                  [u1, u2, u3, u2])

        check(forward=((g1, [u1, u2, u3, u4]),
                       (g2, [u3, u1]),
                       (g3, [u2])),
              back=((u1, [g2, g1]),
                    (u2, [g3, g1]),
                    (u3, [g2, g1]),
                    (u4, [g1])))

        self.remove_linked_attribute(g2, [u1, u3])
        self.remove_linked_attribute(g1, [u1, u3])

        check(forward=((g1, [u2, u4]),
                       (g2, []),
                       (g3, [u2])),
              back=((u1, []),
                    (u2, [g3, g1]),
                    (u3, []),
                    (u4, [g1])))

        self.add_linked_attribute(g1, [u1, u3])
        self.add_linked_attribute(g2, [u3, u1])
        self.add_linked_attribute(g3, [u1, u3])

        check(forward=((g1, [u1, u2, u3, u4]),
                       (g2, [u1, u3]),
                       (g3, [u1, u2, u3])),
              back=((u1, [g1, g2, g3]),
                    (u2, [g3, g1]),
                    (u3, [g3, g2, g1]),
                    (u4, [g1])))

    def test_la_links_replace(self):
        """Replace member lists and check forward and back links; a
        replacement containing a duplicate value must be rejected."""
        u1, u2, u3, u4 = self.add_objects(4, 'user', 'u_replace')
        g1, g2, g3, g4 = self.add_objects(4, 'group', 'g_replace')

        def replace_all(assignments):
            # assignments: (group, new member list) pairs, applied in order.
            for group, members in assignments:
                self.replace_linked_attribute(group, members)

        def check(forward, back):
            # forward: (group, expected members) pairs;
            # back: (user, expected groups) pairs.
            for group, members in forward:
                self.assert_forward_links(group, members)
            for user, memberships in back:
                self.assert_back_links(user, memberships)

        self.add_linked_attribute(g1, [u1, u2])
        self.add_linked_attribute(g2, [u1, u3])
        self.add_linked_attribute(g3, u1)

        replace_all(((g1, [u2]),
                     (g2, [u2, u3]),
                     (g3, [u1, u3]),
                     (g4, [u4])))

        check(forward=((g1, [u2]),
                       (g2, [u3, u2]),
                       (g3, [u3, u1]),
                       (g4, [u4])),
              back=((u1, [g3]),
                    (u2, [g1, g2]),
                    (u3, [g2, g3]),
                    (u4, [g4])))

        replace_all(((g1, [u1, u2, u3]),
                     (g2, [u1]),
                     (g3, [u2]),
                     (g4, [])))

        check(forward=((g1, [u1, u2, u3]),
                       (g2, [u1]),
                       (g3, [u2]),
                       (g4, [])),
              back=((u1, [g1, g2]),
                    (u2, [g1, g3]),
                    (u3, [g1]),
                    (u4, [])))

        self.assertRaisesLdbError(ldb.ERR_ENTRY_ALREADY_EXISTS,
                                  "replacing duplicate values",
                                  self.replace_linked_attribute, g2,
                                  [u1, u2, u3, u2])



    def test_la_links_replace2(self):
        """Grow, shift and shrink one group's member list through add,
        replace and remove, checking the forward links each time."""
        users = self.add_objects(12, 'user', 'u_replace2')
        g1, = self.add_objects(1, 'group', 'g_replace2')

        # (operation, links sent, forward links expected afterwards)
        steps = (
            (self.add_linked_attribute, users[:6], users[:6]),
            (self.replace_linked_attribute, users, users),
            (self.replace_linked_attribute, users[6:], users[6:]),
            (self.remove_linked_attribute, users[6:9], users[9:]),
            (self.remove_linked_attribute, users[9:], []),
        )
        for operation, links, remaining in steps:
            operation(g1, links)
            self.assert_forward_links(g1, remaining)

    def test_la_links_permutations(self):
        """The order in which links are supplied must not matter.

        Each of the six groups receives the three users in a different
        permutation, first via add, then via replace (groups taken in
        reverse), and finally has them removed again.
        """
        users = self.add_objects(3, 'user', 'u_permutations')
        groups = self.add_objects(6, 'group', 'g_permutations')

        def check_everywhere(members, memberships):
            # Every group must list ``members``; every user must list
            # ``memberships``.
            for group in groups:
                self.assert_forward_links(group, members)
            for user in users:
                self.assert_back_links(user, memberships)

        for group, ordering in zip(groups, itertools.permutations(users)):
            self.add_linked_attribute(group, ordering)

        # everyone should be in every group
        check_everywhere(users, groups)

        for group, ordering in zip(groups[::-1],
                                   itertools.permutations(users)):
            self.replace_linked_attribute(group, ordering)

        check_everywhere(users, groups)

        for group, ordering in zip(groups, itertools.permutations(users)):
            self.remove_linked_attribute(group, ordering)

        check_everywhere([], [])

    def test_la_links_relaxed(self):
        """Check that the relax control doesn't mess with linked attributes.

        Three groups start with the first two users as members.  Two of
        the groups are then modified with the ``relax:0`` control and one
        without it; after every stage all three are expected to show the
        same forward links.
        """
        relax_control = ['relax:0']

        users = self.add_objects(10, 'user', 'u_relax')
        groups = self.add_objects(3, 'group', 'g_relax',
                                  more_attrs={'member': users[:2]})
        g_relax1, g_relax2, g_uptight = groups

        # g_relax1 has all users added at once
        # g_relax2 gets them one at a time in reverse order
        # g_uptight never relaxes

        self.add_linked_attribute(g_relax1, users[2:5], controls=relax_control)

        for u in reversed(users[2:5]):
            self.add_linked_attribute(g_relax2, u, controls=relax_control)
            self.add_linked_attribute(g_uptight, u)

        # From here on each group is grown without the relax control:
        # two users in one message, then the rest one at a time.
        for g in groups:
            self.assert_forward_links(g, users[:5])

            self.add_linked_attribute(g, users[5:7])
            self.assert_forward_links(g, users[:7])

            for u in users[7:]:
                self.add_linked_attribute(g, u)

            self.assert_forward_links(g, users)

        for u in users:
            self.assert_back_links(u, groups)

        # try some replacement permutations
        # (fixed seed, so the shuffled orders are the same on every run)
        import random
        random.seed(1)
        users2 = users[:]
        for i in range(5):
            random.shuffle(users2)
            self.replace_linked_attribute(g_relax1, users2,
                                          controls=relax_control)

            self.assert_forward_links(g_relax1, users)

        # Empty all three groups, then refill them from a freshly
        # shuffled list; the resulting links must match each time.
        for i in range(5):
            random.shuffle(users2)
            self.remove_linked_attribute(g_relax2, users2,
                                         controls=relax_control)
            self.remove_linked_attribute(g_uptight, users2)

            self.replace_linked_attribute(g_relax1, [], controls=relax_control)

            random.shuffle(users2)
            self.add_linked_attribute(g_relax2, users2,
                                      controls=relax_control)
            self.add_linked_attribute(g_uptight, users2)
            self.replace_linked_attribute(g_relax1, users2,
                                          controls=relax_control)

            self.assert_forward_links(g_relax1, users)
            self.assert_forward_links(g_relax2, users)
            self.assert_forward_links(g_uptight, users)

        for u in users:
            self.assert_back_links(u, groups)

    def test_add_all_at_once(self):
        """Create groups whose member links are supplied in the add itself.

        The other tests attach links to objects that already exist; here
        the ``member`` values are passed as part of the group's initial
        attributes.
        """
        users = self.add_objects(7, 'user', 'u_all_at_once')
        first_five = users[:5]
        g1, g3 = self.add_objects(2, 'group', 'g_all_at_once',
                                  more_attrs={'member': users})
        [g2] = self.add_objects(1, 'group', 'g_all_at_once2',
                                more_attrs={'member': first_five})

        # A member list repeating one user must make the add fail.
        self.assertRaisesLdbError(ldb.ERR_ENTRY_ALREADY_EXISTS,
                                  "adding multiple duplicate values",
                                  self.add_objects, 1, 'group',
                                  'g_with_duplicate_links',
                                  more_attrs={'member': users[:5] + users[1:2]})

        self.assert_forward_links(g1, users)
        self.assert_forward_links(g2, first_five)
        self.assert_forward_links(g3, users)
        for user in first_five:
            self.assert_back_links(user, [g1, g2, g3])
        for user in users[5:]:
            self.assert_back_links(user, [g1, g3])

        # Shift g2's membership from users[0:5] to users[1:7].
        for user in (users[0], users[1]):
            self.remove_linked_attribute(g2, user)
        for user in (users[1], users[5], users[6]):
            self.add_linked_attribute(g2, user)

        self.assert_forward_links(g1, users)
        self.assert_forward_links(g2, users[1:])

        for user in users[1:]:
            self.remove_linked_attribute(g2, user)
        self.remove_linked_attribute(g1, users)

        for user in users:
            self.samdb.delete(user)

        for group in (g1, g2, g3):
            self.assert_forward_links(group, [])

    def test_one_way_attributes(self):
        """Check an addressBookRoots link after its target is deleted.

        Links one msExchConfigurationContainer object to another through
        ``addressBookRoots``, deletes the target, looks the deleted object
        up by GUID and asserts that the source's forward link now names
        the DN found by that lookup.
        """
        link_attr = 'addressBookRoots'
        source, target = self.add_objects(2, 'msExchConfigurationContainer',
                                          'e_one_way')
        # Remember the GUID before deleting, so the object can be found
        # again afterwards regardless of the DN it ends up with.
        target_guid = self.get_object_guid(target)

        self.add_linked_attribute(source, target, attr=link_attr)
        self.assert_forward_links(source, [target], attr=link_attr)

        self.samdb.delete(target)

        search_base = "<GUID=%s>" % target_guid
        search_controls = ['show_deleted:1', 'show_recycled:1']
        found = self.samdb.search(search_base,
                                  scope=ldb.SCOPE_BASE,
                                  controls=search_controls)

        deleted_dn = str(found[0].dn)
        self.assert_forward_links(source, [deleted_dn], attr=link_attr)
        self.assert_forward_links(source, [deleted_dn], attr=link_attr,
                                  show_deactivated_link=0)

    def test_one_way_attributes_delete_link(self):
        """Check that removing an addressBookRoots link leaves no links.

        Adds a link between two msExchConfigurationContainer objects,
        removes it again and asserts that the source has no forward links,
        both by default and with ``show_deactivated_link=0``.
        """
        link_attr = 'addressBookRoots'
        source, target = self.add_objects(2, 'msExchConfigurationContainer',
                                          'e_one_way')
        # The GUID is not used below; the lookup is kept so the sequence
        # of operations performed by this test stays the same.
        target_guid = self.get_object_guid(target)

        self.add_linked_attribute(source, target, attr=link_attr)
        self.assert_forward_links(source, [target], attr=link_attr)

        self.remove_linked_attribute(source, target, attr=link_attr)

        self.assert_forward_links(source, [], attr=link_attr)
        self.assert_forward_links(source, [], attr=link_attr,
                                  show_deactivated_link=0)

    def test_pretend_one_way_attributes(self):
        """Check an addressBookRoots2 link after its target is deleted.

        Links one msExchConfigurationContainer object to another through
        ``addressBookRoots2``, deletes the target and asserts that the
        source then has no forward links, both by default and with
        ``show_deactivated_link=0``.
        """
        link_attr = 'addressBookRoots2'
        source, target = self.add_objects(2, 'msExchConfigurationContainer',
                                          'e_one_way')
        target_guid = self.get_object_guid(target)

        self.add_linked_attribute(source, target, attr=link_attr)
        self.assert_forward_links(source, [target], attr=link_attr)

        self.samdb.delete(target)

        search_base = "<GUID=%s>" % target_guid
        search_controls = ['show_deleted:1', 'show_recycled:1']
        found = self.samdb.search(search_base,
                                  scope=ldb.SCOPE_BASE,
                                  controls=search_controls)

        # The DN itself is not used below; indexing the result still
        # fails if the deleted object cannot be found by its GUID.
        deleted_dn = str(found[0].dn)

        self.assert_forward_links(source, [], attr=link_attr)
        self.assert_forward_links(source, [], attr=link_attr,
                                  show_deactivated_link=0)

    def test_pretend_one_way_attributes_delete_link(self):
        """Check that removing an addressBookRoots2 link leaves no links.

        Adds a link between two msExchConfigurationContainer objects,
        removes it again and asserts that the source has no forward links,
        both by default and with ``show_deactivated_link=0``.
        """
        link_attr = 'addressBookRoots2'
        source, target = self.add_objects(2, 'msExchConfigurationContainer',
                                          'e_one_way')
        # The GUID is not used below; the lookup is kept so the sequence
        # of operations performed by this test stays the same.
        target_guid = self.get_object_guid(target)

        self.add_linked_attribute(source, target, attr=link_attr)
        self.assert_forward_links(source, [target], attr=link_attr)

        self.remove_linked_attribute(source, target, attr=link_attr)

        self.assert_forward_links(source, [], attr=link_attr)
        self.assert_forward_links(source, [], attr=link_attr,
                                  show_deactivated_link=0)
        dn = 'samAccountName=dns-%s,CN=Principals' % names.hostname
        msg = secretsdb.search(expression='(dn=%s)' % dn, attrs=['secret'])
        dnssecret = msg[0]['secret'][0]
    except Exception:
        print "Adding dns-%s account" % names.hostname

        try:
            msg = samdb.search(base=names.domaindn,
                               scope=samba.ldb.SCOPE_DEFAULT,
                               expression='(sAMAccountName=dns-%s)' %
                               (names.hostname),
                               attrs=['clearTextPassword'])
            if msg:
                print "removing sAMAccountName=dns-%s" % (names.hostname)
                dn = msg[0].dn
                samdb.delete(dn)
        except Exception:
            print "exception while removing sAMAccountName=dns-%s" % (
                names.hostname)
            pass

        setup_add_ldif(
            secretsdb, setup_path("secrets_dns.ldif"), {
                "REALM":
                names.realm,
                "DNSDOMAIN":
                names.dnsdomain,
                "DNS_KEYTAB":
                dns_keytab_path,
                "DNSPASS_B64":
                b64encode(dnspass),